Gráficos de dispersão em Pandas / Pyplot: como traçar por categoria

90

Estou tentando fazer um gráfico de dispersão simples em pyplot usando um objeto Pandas DataFrame, mas quero uma maneira eficiente de plotar duas variáveis, mas tem os símbolos ditados por uma terceira coluna (chave). Eu tentei várias maneiras usando df.groupby, mas não com sucesso. Um exemplo de script df está abaixo. Isso colore os marcadores de acordo com 'chave1', mas gostaria de ver uma legenda com as categorias 'chave1'. Eu estou perto? Obrigado.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111)
ax1.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)
plt.show()
user2989613
fonte

Respostas:

120

Você pode usar scatterpara isso, mas isso requer ter valores numéricos para o seu key1, e você não terá uma legenda, como você notou.

É melhor usar apenas plotpara categorias discretas como esta. Por exemplo:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

groups = df.groupby('label')

# Plot
fig, ax = plt.subplots()
ax.margins(0.05) # Optional, just adds 5% padding to the autoscaling
for name, group in groups:
    ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
ax.legend()

plt.show()

insira a descrição da imagem aqui

Se você quiser que as coisas pareçam com o pandasestilo padrão , basta atualizar o rcParamscom a folha de estilo do pandas e usar seu gerador de cores. (Também estou ajustando ligeiramente a legenda):

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

groups = df.groupby('label')

# Plot
plt.rcParams.update(pd.tools.plotting.mpl_stylesheet)
colors = pd.tools.plotting._get_standard_colors(len(groups), color_type='random')

fig, ax = plt.subplots()
ax.set_color_cycle(colors)
ax.margins(0.05)
for name, group in groups:
    ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
ax.legend(numpoints=1, loc='upper left')

plt.show()

insira a descrição da imagem aqui

Joe Kington
fonte
Por que no exemplo RGB acima o símbolo é mostrado duas vezes na legenda? Como mostrar apenas uma vez?
Steve Schulist de
1
@SteveSchulist - Use ax.legend(numpoints=1)para mostrar apenas um marcador. Existem dois, como Line2Dacontece com um , geralmente há uma linha conectando os dois marcadores.
Joe Kington
Este código só funcionou para mim após adicionar plt.hold(True)após o ax.plot()comando. Alguma ideia do porquê?
Yuval Atzmon
set_color_cycle() foi descontinuado no matplotlib 1.5. Existe set_prop_cycle(), agora.
ale
52

Isso é simples de fazer com Seaborn ( pip install seaborn) como um oneliner

sns.scatterplot(x_vars="one", y_vars="two", data=df, hue="key1") :

import seaborn as sns
import pandas as pd
import numpy as np
np.random.seed(1974)

df = pd.DataFrame(
    np.random.normal(10, 1, 30).reshape(10, 3),
    index=pd.date_range('2010-01-01', freq='M', periods=10),
    columns=('one', 'two', 'three'))
df['key1'] = (4, 4, 4, 6, 6, 6, 8, 8, 8, 8)

sns.scatterplot(x="one", y="two", data=df, hue="key1")

insira a descrição da imagem aqui

Aqui está o dataframe para referência:

insira a descrição da imagem aqui

Como você tem três colunas de variáveis ​​em seus dados, pode desejar plotar todas as dimensões de pares com:

sns.pairplot(vars=["one","two","three"], data=df, hue="key1")

insira a descrição da imagem aqui

https://rasbt.github.io/mlxtend/user_guide/plotting/category_scatter/ é outra opção.

Bob Baxley
fonte
19

Com plt.scatter, só consigo pensar em um: para usar um artista proxy:

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111)
x=ax1.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)

ccm=x.get_cmap()
circles=[Line2D(range(1), range(1), color='w', marker='o', markersize=10, markerfacecolor=item) for item in ccm((array([4,6,8])-4.0)/4)]
leg = plt.legend(circles, ['4','6','8'], loc = "center left", bbox_to_anchor = (1, 0.5), numpoints = 1)

E o resultado é:

insira a descrição da imagem aqui

CT Zhu
fonte
10

Você pode usar df.plot.scatter e passar uma matriz para o argumento c = que define a cor de cada ponto:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
colors = np.where(df["key1"]==4,'r','-')
colors[df["key1"]==6] = 'g'
colors[df["key1"]==8] = 'b'
print(colors)
df.plot.scatter(x="one",y="two",c=colors)
plt.show()

insira a descrição da imagem aqui

Arjaan Buijk
fonte
4

Você também pode tentar o Altair ou o ggpot, que se concentram em visualizações declarativas.

import numpy as np
import pandas as pd
np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

Código Altair

from altair import Chart
c = Chart(df)
c.mark_circle().encode(x='x', y='y', color='label')

insira a descrição da imagem aqui

código ggplot

from ggplot import *
ggplot(aes(x='x', y='y', color='label'), data=df) +\
geom_point(size=50) +\
theme_bw()

insira a descrição da imagem aqui

Nipun Batra
fonte
4

A partir do matplotlib 3.1 em diante, você pode usar .legend_elements(). Um exemplo é mostrado na criação de legenda automatizada . A vantagem é que uma única chamada de dispersão pode ser usada.

Nesse caso:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), 
                  index = pd.date_range('2010-01-01', freq = 'M', periods = 10), 
                  columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)


fig, ax = plt.subplots()
sc = ax.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)
ax.legend(*sc.legend_elements())
plt.show()

insira a descrição da imagem aqui

Caso as chaves não tenham sido fornecidas diretamente como números, seria

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), 
                  index = pd.date_range('2010-01-01', freq = 'M', periods = 10), 
                  columns = ('one', 'two', 'three'))
df['key1'] = list("AAABBBCCCC")

labels, index = np.unique(df["key1"], return_inverse=True)

fig, ax = plt.subplots()
sc = ax.scatter(df['one'], df['two'], marker = 'o', c = index, alpha = 0.8)
ax.legend(sc.legend_elements()[0], labels)
plt.show()

insira a descrição da imagem aqui

ImportanceOfBeingErnest
fonte
Recebi um erro dizendo que o objeto 'PathCollection' não tem o atributo 'legends_elements'. Meu código é o seguinte. fig, ax = plt.subplots(1, 1, figsize = (4,4)) scat = ax.scatter(rand_jitter(important_dataframe["workout_type_int"], jitter = 0.04), important_dataframe["distance"], c = color_list, marker = 'o', alpha = 0.9) print(scat.legends_elements()) #ax.legend(*scat.legend_elements())
Nandish Patel de
1
@NandishPatel Verifique a primeira frase desta resposta. Além disso, certifique-se de não confundir legends_elementse legend_elements.
ImportanceOfBeingErnest de
Sim obrigado. Isso foi um erro de digitação (legendas / legenda). Eu estava trabalhando em algo desde as últimas 6 horas, então a versão Matplotlib não me ocorreu. Achei que estava usando o mais recente. Fiquei confuso que a documentação diz que existe esse método, mas o código estava dando um erro. Obrigado novamente. Eu posso dormir agora
Nandish Patel de
1

seaborn tem uma função de invólucro scatterplotque o faz de forma mais eficiente.

sns.scatterplot(data = df, x = 'one', y = 'two', data =  'key1'])
yosemite_k
fonte