gpt4 book ai didi

python - 在 keras 中可视化训练神经网络的权重

转载 作者:行者123 更新时间:2023-12-04 13:15:08 25 4
gpt4 key购买 nike

嗨,我训练了一个卷积层为 96*96*32 的自动编码器网络

现在我得到了名为 autoencoder 的模型的权重

layer=autoencoder.layers[1]
W=layer.get_weights()

由于 w 是一个列表,请帮助我对它的元素进行排序并可视化经过训练的内核。
我猜它应该是 96×96 大小的 32 个内核

当我打字
len(w)

它给了 2 所以我有 2 个数组

顶部数组有 9 个子数组,每个子数组有 32 个数字
最后一个数组有 32 个元素。所以一定是偏见

[array([[[[-6.56146603e-03, -1.51752336e-02, -3.76937017e-02,
-4.55160812e-03, 1.26366820e-02, -2.97747254e-02,
3.76312323e-02, -1.56892575e-02, 2.03932393e-02,
3.29606095e-03, 3.76580656e-02, 6.99581252e-03,
-4.97130565e-02, 3.63005586e-02, 3.70187908e-02,
2.63699284e-03, 4.42482866e-02, 8.26128479e-03,
3.44854854e-02, 1.94760375e-02, 3.91177870e-02,
-6.67006942e-03, 5.64308763e-02, -1.55166145e-02,
-3.46037326e-03, -3.14556211e-02, -2.31548538e-03,
5.77888393e-04, 2.17472352e-02, -8.16953406e-02,
1.54041937e-02, -3.55066173e-02]],

[[ 7.61649990e-03, -6.52475432e-02, 2.02584285e-02,
-4.36152853e-02, -7.94242844e-02, -6.29556971e-03,
-2.17294712e-02, 3.30206454e-02, 3.47386077e-02,
-2.77627818e-03, 4.49984707e-02, -3.03241126e-02,
-3.36903334e-02, 2.34354921e-02, 3.31020765e-02,
-7.81059638e-03, -9.54489596e-03, -1.07985372e-02,
4.10569459e-02, 5.06392084e-02, -1.64809041e-02,
8.42852518e-03, -6.24148361e-03, 1.38165271e-02,
4.47277874e-02, -1.68551356e-02, 2.87279133e-02,
-4.17906158e-02, -3.29194516e-02, 5.37550561e-02,
-3.10864598e-02, -4.53849025e-02]],

[[ 5.37880100e-02, 2.00091377e-02, -8.04780126e-02,
2.05146279e-02, -6.41385652e-03, 2.94176023e-02,
2.42049675e-02, 2.98423916e-02, 1.30865928e-02,
-9.23016574e-03, -2.63463743e-02, -1.58412699e-02,
-4.76215854e-02, -1.53328422e-02, -2.54222248e-02,
1.03113698e-02, 1.97005924e-02, -1.09527409e-02,
-4.29149866e-02, 1.15255425e-02, 3.65356952e-02,
2.26275604e-02, 8.76231957e-03, -1.82650369e-02,
4.30952013e-02, -1.58966344e-03, 1.01399068e-02,
7.15927547e-03, 2.70794444e-02, -1.93151142e-02,
2.06329934e-02, -3.24055366e-02]]],


[[[ 7.32885906e-04, -5.99233769e-02, 1.01583647e-02,
2.62707975e-02, -1.60765275e-02, 4.54364009e-02,
1.22182900e-02, 1.77695882e-02, 3.40870097e-02,
-3.20678158e-03, 1.94115974e-02, -5.89495376e-02,
5.51430099e-02, 1.08586736e-02, -2.14386974e-02,
-1.10124948e-03, -1.41514605e-02, -8.40184465e-03,
-4.09237854e-02, 2.27938611e-02, 2.82027805e-03,
3.99805643e-02, -5.23957238e-02, -6.65743649e-02,
-1.86213956e-03, 1.84283289e-03, 8.22036352e-04,
-2.04587094e-02, -4.95675243e-02, 5.40869832e-02,
4.00022417e-02, -4.74570543e-02]],

[[-3.73015292e-02, 9.84914601e-03, 9.94246900e-02,
3.19805741e-02, 8.14174674e-03, 2.72354241e-02,
-1.58177980e-03, -5.65455444e-02, -2.13499945e-02,
2.36055311e-02, 4.57456382e-03, 5.87781705e-02,
-4.50953143e-03, -3.05559561e-02, 8.65572542e-02,
-2.87776738e-02, 7.56273838e-03, -2.02421043e-02,
4.32164557e-02, 1.07650533e-02, 1.74834915e-02,
-2.26386450e-02, -4.51299828e-03, -7.19766971e-03,
-5.64673692e-02, -3.46505865e-02, -9.57003422e-03,
-4.17267382e-02, 2.74983943e-02, 7.50013590e-02,
-1.39447292e-02, -2.10063234e-02]],

[[-4.99953330e-03, -1.95915010e-02, 7.38414973e-02,
3.00457701e-02, 4.11909744e-02, -4.93509434e-02,
-3.72827090e-02, -4.84874584e-02, -1.73344277e-02,
2.13540550e-02, 2.63152272e-02, 5.11181913e-02,
5.94335012e-02, -8.46157200e-04, -3.79960015e-02,
-2.01609023e-02, 2.21411046e-02, -1.14003820e-02,
-1.78077854e-02, -6.17240835e-03, -9.96494666e-03,
-2.70768851e-02, 3.32489684e-02, -1.18451891e-02,
7.48611614e-02, 3.68427448e-02, -1.70680200e-04,
2.78645731e-03, 3.37152109e-02, -6.00774325e-02,
3.43431458e-02, 6.80516511e-02]]],


[[[ 4.51148823e-02, 4.12209071e-02, -1.92945134e-02,
-2.68811788e-02, 4.68725041e-02, -2.08357088e-02,
-3.62888947e-02, -1.60191804e-02, 3.19913588e-02,
1.54639455e-02, -7.92380888e-03, -4.85247411e-02,
-3.52074914e-02, -1.04825860e-02, -6.63231388e-02,
4.35819328e-02, 1.74060687e-02, -3.14022303e-02,
-2.88435258e-02, -2.56987382e-03, -4.61222306e-02,
9.01424140e-03, -3.54990773e-02, 3.61517034e-02,
-4.51472104e-02, -1.96188372e-02, 2.76502203e-02,
-3.39846462e-02, -5.75804268e-04, -4.55158725e-02,
2.47761561e-03, 5.08131757e-02]],

[[ 3.74217257e-02, 4.53428067e-02, -4.36269939e-02,
-1.65079869e-02, -2.69084796e-02, -2.38134293e-03,
2.26788968e-02, -3.10470518e-02, -4.33242172e-02,
1.89485904e-02, -5.52747138e-02, 6.01334386e-02,
-1.70235410e-02, -4.17503342e-02, -1.59652822e-03,
-3.10646854e-02, -1.94913559e-02, 5.42740058e-03,
5.47912866e-02, 2.19548331e-03, -2.94116754e-02,
2.24571414e-02, -1.57341175e-02, -5.24678500e-03,
4.41270098e-02, 1.79115515e-02, -3.40841003e-02,
-2.95497216e-02, 4.40835916e-02, 4.28234115e-02,
-4.25039157e-02, 5.90493456e-02]],

[[-2.71476209e-02, 6.84098527e-02, -2.91980486e-02,
-2.52507403e-02, -6.22444265e-02, 3.67519422e-03,
5.06899729e-02, 3.09969904e-03, 4.50362265e-02,
8.56801707e-05, 4.21552844e-02, -3.78406122e-02,
-1.73772611e-02, 4.68185954e-02, -6.93227863e-03,
-4.71074954e-02, 5.72011899e-03, -1.59831103e-02,
-1.66428182e-02, 1.12894354e-02, 5.62585844e-03,
1.36870472e-02, -2.89466791e-02, -2.87153292e-03,
-3.21626514e-02, -3.75866666e-02, -1.62240565e-02,
3.01954672e-02, -2.69964593e-03, -2.27513053e-02,
2.10835561e-02, -4.13369946e-02]]]], dtype=float32),
array([-1.1922461e-03, -2.0752363e-04, 1.1357996e-05, 1.6377015e-05,
-2.5950783e-04, 1.9307183e-05, -1.5572178e-06, -1.3648998e-03,
-8.6763187e-04, 4.4856939e-04, 2.7988455e-03, -7.7398616e-04,
-5.1178242e-04, -6.8265648e-04, 1.8571866e-04, -7.1992702e-04,
-5.5880222e-04, -3.6114815e-04, -9.7678707e-04, 2.6443407e-03,
1.1190268e-03, -1.0251488e-03, -1.1638318e-03, 7.1209669e-04,
4.9417594e-04, 2.3746442e-04, -4.8552561e-04, 1.4480414e-03,
-1.8445569e-05, 4.2989667e-04, 1.0579359e-04, -3.2821635e-04],
dtype=float32)]


模型几个起始层的总结

Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 96, 96, 1) 0
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 96, 96, 32) 320 input_1[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 96, 96, 32) 128 conv2d_1[0][0]



现在我如何订购它们并可视化

我正在使用 keras

谢谢

最佳答案

通常,如果您使用的是 Dense 层,那么第一个 lenth 2 对应于权重向量和偏置向量。

由于我不知道您的图层类型,因此我正在添加一个示例来解释 Dense、Conv2D 图层的形状。

第一个长度总是对应于权重和偏差,权重和偏差的第二个形状不同,对于偏差,它总是一个数组,对于 Dense,权重有一个形状 (input_dim, output_dim),对于 Conv2D (channels, kernel_h, kernel_w, num_filters)。

from tensorflow.keras.layers import *
from tensorflow.keras.models import *
import numpy as np

i1 = Input(shape=(32,32,3))
c1 = Conv2D(32, 3)(i1)
f1 = Flatten()(c1)
d1 = Dense(5)(f1)

m = Model(i1, d1)

m.summary()

y = m(np.zeros((1, 32, 32, 3)))

print(m.layers)
cw1 = np.array(m.layers[1].get_weights())
print(cw1.shape) # 2 weight, 1 weight, 1 bias
print(cw1[0].shape) # 3 channels, 3 by 3 kernels, 32 filters
print(cw1[1].shape) # 32 biases

cw1 = np.array(m.layers[2].get_weights())
print(cw1.shape) # this is just a flatten operations, so no weights

cw1 = np.array(m.layers[3].get_weights())
print(cw1.shape) # 2 -> 1 weight, 1 bias
print(cw1[0].shape) # 28800 inputs, 5 outputs, 28800 by 5 weight matrix
print(cw1[1].shape) # 5 biases
Model: "model_13"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_14 (InputLayer) [(None, 32, 32, 3)] 0
_________________________________________________________________
conv2d_13 (Conv2D) (None, 30, 30, 32) 896
_________________________________________________________________
flatten_13 (Flatten) (None, 28800) 0
_________________________________________________________________
dense_13 (Dense) (None, 5) 144005
=================================================================
Total params: 144,901
Trainable params: 144,901
Non-trainable params: 0
_________________________________________________________________
[<tensorflow.python.keras.engine.input_layer.InputLayer object at 0x7fb8ce3bb828>, <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fb8ce5fd6d8>, <tensorflow.python.keras.layers.core.Flatten object at 0x7fb8ce3bb940>, <tensorflow.python.keras.layers.core.Dense object at 0x7fb8ce3bbb70>]
(2,)
(3, 3, 3, 32)
(32,)
(0,)
(2,)
(28800, 5)
(5,)


可视化完全取决于维度。

如果是一维,
import matplotlib.pyplot as plt
plt.plot(weight)
plt.show()

如果是二维的,
import matplotlib.pyplot as plt
plt.imshow(weight)
plt.show()

如果是3D,

您可以选择一个 channel 并仅绘制该部分。

# plotting the 32 conv filter
import matplotlib.pyplot as plt
cw1 = np.array(m.layers[1].get_weights())
for i in range(32):
plt.imshow(cw1[0][:,:,:,i])
plt.show()

enter image description here

关于python - 在 keras 中可视化训练神经网络的权重,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61288116/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com