gpt4 book ai didi

python - 为什么 Keras Conv1D 权重在训练期间没有改变?

转载 作者:行者123 更新时间:2023-12-01 08:07:00 24 4
gpt4 key购买 nike

我仅使用一个卷积层(8 个长度为 10 的过滤器)初始化网络。

# Initialize Convolutional Neural Network
cnn = Sequential()
conv = Conv1D(filters=8, kernel_size=10, strides=1, padding="same", input_shape=(train.values.shape[1]-1, 1))
cnn.add(conv)
cnn.add(Activation("relu"))
cnn.add(MaxPooling1D(pool_size=2, strides=2, padding="same"))
cnn.add(Flatten())
cnn.add(Dense(2, activation='softmax'))
cnn.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
cnn.summary()

我在训练前和训练后分别获取一次权重,并用我编写的函数绘制它们。

w1 = conv.get_weights()[0][:, 0, :]
print(w1[:,0])
plot_weights(w1)

# Fit CNN
y = to_categorical(train.values[:, -1])
X_cnn = np.expand_dims(train.values[:, :-1], axis=2)
start = time.time()
cnn.fit(X_cnn, y, verbose=1, batch_size=20, validation_split=0.2, epochs=20)
end = time.time()

w2 = conv.get_weights()[0][:, 0, :]
print(w2[:,0])
plot_weights(w2)

绘制权重的函数:

def plot_weights(w):
w_min = w.min()
w_max = w.max()
n = w.shape[0]
fig, axes = plt.subplots(nrows=8, ncols=1)
for i, ax in enumerate(axes.flat):
im = ax.imshow(w[:, i].reshape(1, n), vmin=w_min, vmax=w_max, interpolation="nearest",
cmap="gray") # Display weights as image
plt.setp(ax.get_yticklabels(), visible=False) # Hide y ticks
ax.tick_params(axis='y', which='both', length=0) # Set length of y ticks to 0

fig.colorbar(im, ax=axes.ravel().tolist())
plt.show(block=False)

return

输出如下所示:

Before training

After training

当我在训练前后打印第一个过滤器时,您还可以看到它是完全相同的数字(甚至没有稍微改变)。

>>>[-0.20076838  0.03835052 -0.04454999 -0.20220913  0.24402907  0.03407234
-0.09768075 0.16887552 0.12767741 0.00756356]
>>>[-0.20076838 0.03835052 -0.04454999 -0.20220913 0.24402907 0.03407234
-0.09768075 0.16887552 0.12767741 0.00756356]

这种行为的原因是什么?难道我做错了什么?网络显然正在学习一些东西,因为我的准确率接近 100%。

--ga97dil

最佳答案

您可能需要访问正在训练的模型本身,即 cnn而不是用于初始化图层的定义,即 conv .

尝试cnn.layers[0].get_weights()[:, 0, :]而不是conv.get_weights()[0][:, 0, :] .

关于python - 为什么 Keras Conv1D 权重在训练期间没有改变?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55490811/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com