gpt4 book ai didi

python - 如何平衡 numpy 数组中的类?

转载 作者:太空宇宙 更新时间:2023-11-04 00:35:27 24 4
gpt4 key购买 nike

我有 2 个 numpy 数组,如下所示:

images 包含图像文件的名称(images.shape 是 (N, 3, 128, 128)):
image_1.jpg
image_2.jpg
image_3.jpg
image_4.jpg

labels包含对应的标签(0-3)(labels.shape为(N,)):
1个
1个
3个
2个

我面临的问题是类别不平衡,类别 3 >> 1 > 2 > 0。

我想通过以下方式平衡最终数据集:

  • 计算每个类别中的图像(样本)数量
  • 获取图像数量最少的类别的数量
  • 使用该计数作为其他 3 个类别的最大图像/标签数
  • 随机弹出 imageslabels 中其他 3 个类的多余图像/标签

到目前为止,我正在使用 Counter 来确定每个类别的图像数量:

from Collections import Counter
import numpy as np

count = Counter(labels)
print(count)

>>>Counter({'1': 2991, '0': 2953, '2': 2510, '3': 2488})

你会如何建议我从 imageslabels 中随机弹出匹配元素,以便它们包含 2488 个类 0、1 和 2 的样本?

最佳答案

您可以使用 np.random.choice 创建一个整数值掩码,您可以将其应用于标签和图像以平衡数据集:

n = 2488

mask = np.hstack([np.random.choice(np.where(labels == l)[0], n, replace=False)
for l in np.unique(labels)])

关于python - 如何平衡 numpy 数组中的类?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44232900/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com