gpt4 book ai didi

tensorflow - 不平衡图像数据集 (Tensorflow2)

转载 作者:行者123 更新时间:2023-12-04 17:21:19 24 4
gpt4 key购买 nike

我正在尝试解决二值图像分类问题,但这两个类(分别为 ~590 和 ~5900 个实例,分别对应于类 1 和 2)严重偏斜,但仍然截然不同。

有什么办法可以解决这个问题,我想尝试 SMOTE/随机加权过采样。

我尝试了很多不同的方法,但我被卡住了。我试过使用 class_weights=[10,1][5900,590][1/5900,1/590]而我的模型仍然只预测 2 类。我试过使用 tf.data.experimental.sample_from_datasets 但我无法让它工作。我什至尝试过使用 sigmoid 焦点交叉熵损失,这有很大帮助但还不够。

我希望能够对 1 类进行 10 倍的过采样,我尝试过的唯一有点奏效的方法是手动过采样,即复制训练目录的 1 类实例以匹配 2 类中的实例数。

没有更简单的方法吗,我使用的是 Google Colab,所以这样做效率极低。

有没有办法在数据生成器或类似工具中指定 SMOTE 参数/过采样?

data/
...class_1/
........image_1.jpg
........image_2.jpg
...class_2/
........image_1.jpg
........image_2.jpg

我的数据是上面显示的形式。

TRAIN_DATAGEN = ImageDataGenerator(rescale = 1./255.,
rotation_range = 40,
width_shift_range = 0.2,
height_shift_range = 0.2,
shear_range = 0.2,
zoom_range = 0.2,
horizontal_flip = True)

TEST_DATAGEN = ImageDataGenerator(rescale = 1.0/255.)

TRAIN_GENERATOR = TRAIN_DATAGEN.flow_from_directory(directory = TRAIN_DIR,
batch_size = BACTH_SIZE,
class_mode = 'binary',
target_size = (IMG_HEIGHT, IMG_WIDTH),
subset = 'training',
seed = DATA_GENERATOR_SEED)

VALIDATION_GENERATOR = TEST_DATAGEN.flow_from_directory(directory = VALIDATION_DIR,
batch_size = BACTH_SIZE,
class_mode = 'binary',
target_size = (IMG_HEIGHT, IMG_WIDTH),
subset = 'validation',
seed = DATA_GENERATOR_SEED)
...
...
...

HISTORY = MODEL.fit(TRAIN_GENERATOR,
validation_data = VALIDATION_GENERATOR,
epochs = EPOCHS,
verbose = 2,
callbacks = [EARLY_STOPPING],
class_weight = CLASS_WEIGHT)

我是 Tensorflow 的新手,但我对整个 ML 有一些经验。我多次想切换到 PyTorch,因为它们具有数据加载器的参数,可以使用 sampler=WeightedRandomSampler 自动(过度/不足)采样。

注意:我看过很多关于如何过采样的教程,但没有一个是图像分类问题,我想坚持使用 TF/Keras,因为它可以轻松进行迁移学习,你们能帮忙吗?

最佳答案

您可以使用此策略根据不平衡计算权重:

from sklearn.utils import class_weight 
import numpy as np

class_weights = class_weight.compute_class_weight(
'balanced',
np.unique(train_generator.classes),
train_generator.classes)

train_class_weights = dict(enumerate(class_weights))
model.fit_generator(..., class_weight=train_class_weights)

关于tensorflow - 不平衡图像数据集 (Tensorflow2),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66016844/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com