gpt4 book ai didi

python - 我在numpy中实现的多 channel 一维卷积有什么问题(与tensorflow相比)

转载 作者:行者123 更新时间:2023-12-01 06:29:13 27 4
gpt4 key购买 nike

为了确保我理解 TensorFlow 的卷积运算,我在 numpy 中实现了具有多个 channel 的 conv1d。但是,我得到不同的结果,我看不到问题所在。与 conv1d 相比,我的实现似乎将重叠值加倍。

代码:

import tensorflow as tf
import numpy as np

# hand-written multi-channel 1D convolution operator

# "Data", dimensions:
# [0]: sample (2 samples)
# [1]: time index (4 indexes)
# [2]: channels (2 channels)
x = np.array([[1,2,3,4],[5,6,7,8]]).T
x = np.array([x, x+8], dtype=np.float32)

# "Filter", a linear kernel to be convolved along axis 1
y = np.array([[[2,8,6,5,7],[3,9,7,2,1]]], dtype=np.float32)

# convolution along axis=1
w1 = np.zeros(x.shape[:2] + y.shape[2:])
for i in range(1,x.shape[1]-1):
w1[:,i-1:i+2,:] += x[:,i-1:i+2,:] @ y

# check against conv1d:
s = tf.Session()
w2 = s.run(tf.nn.conv1d(x, padding='VALID', filters=y))

然而,这对 w1 和 w2 给出了不同的结果:

In [13]: w1 # Numpy result
Out[13]:
array([[[ 17., 53., 41., 15., 12.],
[ 44., 140., 108., 44., 40.],
[ 54., 174., 134., 58., 56.],
[ 32., 104., 80., 36., 36.]],

[[ 57., 189., 145., 71., 76.],
[124., 412., 316., 156., 168.],
[134., 446., 342., 170., 184.],
[ 72., 240., 184., 92., 100.]]])

In [14]: w2 # Tensorflow result
Out[14]:
array([[[ 17., 53., 41., 15., 12.],
[ 22., 70., 54., 22., 20.],
[ 27., 87., 67., 29., 28.],
[ 32., 104., 80., 36., 36.]],

[[ 57., 189., 145., 71., 76.],
[ 62., 206., 158., 78., 84.],
[ 67., 223., 171., 85., 92.],
[ 72., 240., 184., 92., 100.]]], dtype=float32)

在我的版本中,与 conv1d 相比,重叠索引(中间 2 个)似乎翻了一番。但是,我不知道该怎么做,除法似乎不是正确的做法,因为卷积是一种简单的乘加运算。

知道我做错了什么吗?提前致谢!

编辑:我用 padding='SAME' 得到了相同的结果。

最佳答案

错误在for循环的+=中。您计算 w1[:,1,:]w1[:,2,:] 两次并将它们添加到自己。只需将 += 替换为 =,或者简单地执行以下操作:

>>> x @ y
array([[[ 17., 53., 41., 15., 12.],
[ 22., 70., 54., 22., 20.],
[ 27., 87., 67., 29., 28.],
[ 32., 104., 80., 36., 36.]],

[[ 57., 189., 145., 71., 76.],
[ 62., 206., 158., 78., 84.],
[ 67., 223., 171., 85., 92.],
[ 72., 240., 184., 92., 100.]]], dtype=float32)

关于python - 我在numpy中实现的多 channel 一维卷积有什么问题(与tensorflow相比),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59521443/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com