gpt4 book ai didi

python - 带有许多字符串标签的 Matplotlib matshow

转载 作者:太空狗 更新时间:2023-10-29 18:30:23 24 4
gpt4 key购买 nike

今天我尝试从我的分类模型中绘制混淆矩阵。

在一些页面中搜索后,我发现 pyplot 中的 matshow 可以帮助我。

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues, labels=None):
fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(cm)
plt.title(title)
fig.colorbar(cax)
if labels:
ax.set_xticklabels([''] + labels)
ax.set_yticklabels([''] + labels)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()

如果我的标签很少,它会很好用

y_true = ['a', 'b', 'c', 'd', 'a', 'b', 'c', 'a', 'c', 'd', 'b', 'a', 'b', 'a']
y_pred = ['a', 'b', 'c', 'd', 'a', 'b', 'b', 'a', 'c', 'a', 'a', 'a', 'a', 'a']
labels = list(set(y_true))
cm = confusion_matrix(y_true, y_pred)
plot_confusion_matrix(cm, labels=labels)

enter image description here

但是如果我有很多标签,有些标签不会正确显示

y_true = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n']
y_pred = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n']
labels = list(set(y_true))
cm = confusion_matrix(y_true, y_pred)
plot_confusion_matrix(cm, labels=labels)

enter image description here

我的问题是如何在 matshow 图中显示所有标签?我尝试了类似 fontdict 的方法,但它仍然无法正常工作

最佳答案

您可以使用 matplotlib.ticker 控制报价的频率模块。

在这种情况下,您希望每 1 的倍数设置一个刻度,因此我们可以使用 MultipleLocator

在调用 plt.show() 之前添加这两行:

ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

它会为您的 y_truey_pred 中的每个字母生成一个勾号和标签。

我还更改了您的 matshow 调用以使用您在函数调用中指定的颜色图:

cax = ax.matshow(cm,cmap=cmap)

enter image description here

为了完整起见,您的整个函数将如下所示:

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import matplotlib.ticker as ticker

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues, labels=None):
fig = plt.figure()
ax = fig.add_subplot(111)

# I also added cmap=cmap here, to make use of the
# colormap you specify in the function call
cax = ax.matshow(cm,cmap=cmap)
plt.title(title)
fig.colorbar(cax)
if labels:
ax.set_xticklabels([''] + labels)
ax.set_yticklabels([''] + labels)

ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

plt.xlabel('Predicted')
plt.ylabel('True')
plt.savefig('confusionmatrix.png')

关于python - 带有许多字符串标签的 Matplotlib matshow,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/34781096/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com