gpt4 book ai didi

python - 将 class_weight 添加到 .fit_generator() 会破坏 to_categorical()

转载 作者:太空宇宙 更新时间:2023-11-03 20:26:45 25 4
gpt4 key购买 nike

尝试使用 DataGenerator 类用一堆图像训练 CNN,模型正常工作得很好。问题是训练数据集非常偏向于几个类,所以我想添加 class_weights。但是,每次执行此操作时,将标记类转换为独热数组的代码部分都会出现索引错误。

这适用于在 tensorflow 之上运行的 Keras。有问题的函数是 keras.utils.to_categorical()

这是分类函数:

for i, pdb_id in enumerate(list_enzymes_temp):
mat = precomputed_distance_matrix(pdb_id, self.dim)

X[i,] = mat.distance_matrix.reshape(*self.dim)

y[i] = int(self.labels[pdb_id.upper()][1]) - 1

return X, keras.utils.to_categorical(y, num_classes=self.n_classes)

这是我用来生成权重的函数

def get_class_weights(dictionary, training_enzymes, mode):
'Gets class weights for Keras'
# Initialization
counter = [0 for i in range(6)]

# Count classes
for enzyme in training_enzymes:
counter[int(dictionary[enzyme.upper()][1])-1] += 1
majority = max(counter)

# Make dictionary
class_weights = {i: float(majority/count) for i, count in enumerate(counter)}

# Value according to mode
if mode == 'unbalanced':
for key in class_weights:
class_weights[key] = 1
elif mode == 'balanced':
pass
elif mode == 'mean_1_balanced':
for key in class_weights:
class_weights[key] = (1+class_weights[key])/2

return class_weights

和我的 fit_generator 函数:

model.fit_generator(generator=training_generator,
validation_data=validation_generator,
epochs=max_epochs,
max_queue_size=16,
class_weight=class_weights,
callbacks=[tensorboard])

这里不会出现 IndexError 消息,并且模型在没有添加 class_weights 的情况下可以完美运行:

File "C:\Users\Python\DMCNN\data_generator.py", line 73, in __getitem__
X, y = self.__data_generation(list_enzymes_temp)
File "C:\Users\Python\DMCNN\data_generator.py", line 59, in __data_generation
return X, keras.utils.to_categorical(y, num_classes=self.n_classes)
File "C:\Users\Python\Anaconda3\lib\site-packages\keras\utils\np_utils.py", line 34, in to_categorical
categorical[np.arange(n), y] = 1
IndexError: index 1065353216 is out of bounds for axis 1 with size 6

最佳答案

我在使用 keras.utils.to_categorical 时遇到了同样的错误。我得到的错误是“IndexError:索引 1065353216 超出了尺寸为 2 的轴 1 的范围”,因为我有 2 个类。

我相信这是将 1.0 转换为 1.0f(32 位浮点),因为 1065353216 是 32 位浮点值 1.0 的无符号 32 位整数表示(请在此处查看: Why is 1.0f in C code represented as 1065353216 in the generated assembly? )。就我而言,并非所有批处理都具有相同的长度,这最终会在 X 和 y 中出现一些未填充的空白,从而导致问题。你可以提前检查一下你的W(甚至X和Y)中是否还有一些元素未填写。您还可以看到 keras.utils.to_categorical 具有默认值 dtype='float32'。您可以尝试指定 dtype 例如在您的情况下“return X, keras.utils.to_categorical(y, num_classes=self.n_classes, dtype='uint8')”看看它是否有效。

关于python - 将 class_weight 添加到 .fit_generator() 会破坏 to_categorical(),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57761731/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com