gpt4 book ai didi

python - 使用 legend_elements 设置散点图图例标签

转载 作者:行者123 更新时间:2023-12-01 07:24:24 25 4
gpt4 key购买 nike

我刚刚将 matplotlib 升级到版本 3.1.1,并且正在尝试使用 legend_elements。

我正在 30,000 个扁平灰度图像的数据集上绘制 PCA 前两个分量的散点图。每张图像都标记为四个主类别之一(配饰、服装、鞋类、个人护理)。我通过创建一个值为 0 到 3 的颜色列,按“主类别”对绘图进行颜色编码。

我已阅读 PathCollection.legend_elements 的文档,但尚未成功合并“func”或“fmt”参数。 https://matplotlib.org/3.1.1/api/collections_api.html#matplotlib.collections.PathCollection.legend_elements

此外,我尝试遵循提供的示例: https://matplotlib.org/3.1.1/gallery/lines_bars_and_markers/scatter_with_legend.html

### create column for color codes
masterCat_codes = {'Accessories':0,'Apparel':1, 'Footwear':2, 'Personal Care':3}
df['colors'] = df['masterCategory'].apply(lambda x: masterCat_codes[x])

### create scatter plot
fig, ax = plt.subplots(figsize=(8,8))
scatter = ax.scatter( *full_pca.T, s=.1 , c=df['colors'], label= df['masterCategory'], cmap='viridis')

### using legend_elements
legend1 = ax.legend(*scatter.legend_elements(num=[0,1,2,3]), loc="upper left", title="Category Codes")
ax.add_artist(legend1)
plt.show()

生成的图例标签为 0、1、2、3。(无论我在定义“分散”时是否指定 label = df['masterCategory'],都会发生这种情况)。我希望标签上写着配饰、服装、鞋类、个人护理。

有没有办法用 legend_elements 来完成这个任务?

注意:由于数据集很大并且预处理计算量很大,因此我编写了一个更容易重现的示例:

fake_data = np.array([[1,1],[1,2],[1,3],[2,1],[2,2],[2,3],[3,1],[3,2],[3,3]])
fake_df = pd.DataFrame(fake_data, columns=['X', 'Y'])
groups = np.array(['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C'])
fake_df['Group'] = groups

group_codes = {k:idx for idx, k in enumerate(fake_df.Group.unique())}
fake_df['colors'] = fake_df['Group'].apply(lambda x: group_codes[x])
fig, ax = plt.subplots()
scatter = ax.scatter(fake_data[:,0], fake_data[:,1], c=fake_df['colors'])
legend = ax.legend(*scatter.legend_elements(num=[0,1,2]), loc="upper left", title="Group \nCodes")
ax.add_artist(legend)
plt.show()

enter image description here

最佳答案

解决方案感谢欧内斯特的重要性

group_codes = {k:idx for idx, k in enumerate(fake_df.Group.unique())}
fake_df['colors'] = fake_df['Group'].apply(lambda x: group_codes[x])
fig, ax = plt.subplots(figsize=(8,8))
scatter = ax.scatter(fake_data[:,0], fake_data[:,1], c=fake_df['colors'])

handles = scatter.legend_elements(num=[0,1,2,3])[0] # extract the handles from the existing scatter plot

ax.legend(title='Group\nCodes', handles=handles, labels=group_codes.keys())
plt.show()

enter image description here

关于python - 使用 legend_elements 设置散点图图例标签,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57532115/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com