gpt4 book ai didi

machine-learning - 多标签分类 keras 的奇怪准确性

转载 作者:行者123 更新时间:2023-11-30 08:25:43 24 4
gpt4 key购买 nike

我有一个多标签分类问题,我使用了以下代码,但验证准确率在第一个时期跃升至 99%,考虑到数据的复杂性,这很奇怪,因为输入特征是从 inception 模型中提取的 2048 个特征 (pool3:0 )层和标签是[1000],(这里是包含特征和标签样本的文件的链接: https://drive.google.com/file/d/0BxI_8PO3YBPPYkp6dHlGeExpS1k/view?usp=sharing ), 我在这里做错了什么吗?

注意:标签是稀疏向量,仅包含 1 ~ 10 个条目,其中 1 其余为零

model.compile(optimizer='adadelta', loss='binary_crossentropy', metrics=['accuracy']) 

预测的输出为零!

我在训练模型来干扰预测时做错了什么?

#input is the features file and labels file

def generate_arrays_from_file(path ,batch_size=100):
x=np.empty([batch_size,2048])
y=np.empty([batch_size,1000])
while True:
f = open(path)
i = 1
for line in f:
# create Numpy arrays of input data
# and labels, from each line in the file
words=line.split(',')
words=map(float, words[1:])
x_= np.array(words[0:2048])
y_=words[2048:]
y_= np.array(map(int,y_))
x_=x_.reshape((1, -1))
#print np.squeeze(x_)
y_=y_.reshape((1,-1))
x[i]= x_
y[i]=y_
i += 1
if i == batch_size:
i=1
yield (x, y)

f.close()

model = Sequential()
model.add(Dense(units=2048, activation='sigmoid', input_dim=2048))
model.add(Dense(units=1000, activation="sigmoid",
kernel_initializer="uniform"))
model.compile(optimizer='adadelta', loss='binary_crossentropy', metrics=
['accuracy'])

model.fit_generator(generate_arrays_from_file('train.txt'),
validation_data= generate_arrays_from_file('test.txt'),
validation_steps=1000,epochs=100,steps_per_epoch=1000,
verbose=1)

最佳答案

我认为准确性的问题在于你的输出稀疏。

Keras 使用以下公式计算准确性:

K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)

因此,在您的情况下,只有 1~10 个非零标签,全 0 的预测将产生 99.9% ~ 99% 的准确度。

就不学习的问题而言,我认为问题在于您使用 sigmoid 作为最后的激活并使用 0 或 1 作为输出值。这是不好的做法,因为为了让 sigmoid 返回 0 或 1,它作为输入获得的值必须非常大或非常小,这反射(reflect)了网络具有非常大的(绝对值)权重。此外,由于在每个训练输出中,1 远小于 0,因此网络很快就会达到一个稳定点,在该点上它只是输出全零(这种情况下的损失也不是很大,应该在 0.016~0.16 左右)。

您可以做的是缩放输出标签,使其介于 (0.2, 0.8) 之间,这样网络的权重就不会变得太大或太小。或者,您可以使用 relu 作为激活函数。

关于machine-learning - 多标签分类 keras 的奇怪准确性,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44833816/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com