gpt4 book ai didi

python - 如何添加数组列表(张量)

转载 作者:太空宇宙 更新时间:2023-11-04 03:59:47 27 4
gpt4 key购买 nike

我正在定义一个简单的 conv2d 函数来计算输入和内核(均为 2D 张量)之间的互相关,如下所示:

import torch 

def conv2D(X, K):
h = K.shape[0]
w = K.shape[1]
ĥ = X.shape[0] - h + 1
ŵ = X.shape[1] - w + 1
Y = torch.zeros((ĥ, ŵ))
for i in range (ĥ):
for j in range (ŵ):
Y[i, j] = (X[i: i+h, j: j+w]*K).sum()

return Y

当 X 和 K 是三阶张量时,我计算每个 channel 的 conv2d,然后将它们加在一起,如下所示:

def conv2D_multiple(X, K):
cross = []
result = 0
for x, k in zip(X, K):
cross.append(conv2D(x,k))

for t in cross:
result += t

return result

测试我的功能:

X_2 = torch.tensor([[[0, 1, 2], [3, 4, 5], [6, 7, 8]], 
[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], dtype=torch.float32)
K_2 = torch.tensor([[[0, 1], [2, 3]], [[1, 2], [3, 4]]], dtype=torch.float32)

conv2D_multiple(X_2, K_2)

结果是:

tensor([[ 56.,  72.],
[104., 120.]])

结果符合预期,但是,我相信我的第二个conv2D_multiple(X, K) 函数中的 for 循环是多余的。我的问题是如何求和(元素明智)张量(数组)在列表中,所以我省略了第二个 for 循环。

最佳答案

由于您的 conv2D 对每个切片行为进行操作,您可以做的是分配一个 3D 张量,以便在您使用第一个 for 循环时存储结果通过获取每个结果并填充每个切片。然后,您可以使用 PyTorch 的内置 torch.sum 沿切片的维度求和。张量上的运算符以获得相同的结果。为了让它更容易接受,我将切片维度设置为 dim=0。因此,将 cross 从初始空列表替换为 3D 的 Torch 张量,以允许您存储中间结果,然后通过求和沿切片维度进行压缩。我们可以逃避这样做,因为您的初始实现将中间结果存储为 2D 张量列表。为了简化操作,转到 3D 并允许 PyTorch 沿切片轴求和。

这将要求您在循环之前首先为该 3D 张量定义正确的维度:

def conv2D_multiple(X, K):
h = K.shape[1]
w = K.shape[2]
ĥ = X.shape[1] - h + 1
ŵ = X.shape[2] - w + 1
c = X.shape[0]
cross = torch.zeros((c, ĥ, ŵ), dtype=torch.float32)
for i, (x, k) in enumerate(zip(X, K)):
cross[i] = conv2D(x,k)

result = cross.sum(dim=0)
return result

请注意,对于每个切片,您在输入和内核之间迭代,而不是附加到新列表,我们直接将其放入中间张量的切片中。存储这些结果后,沿切片轴求和,最终将其压缩为您期望的结果。使用您的示例输入运行上面的新函数会生成相同的结果。


如果这不是您想要的结果,另一种方法是简单地获取您创建的张量列表,通过使用 torch.stack 将它们全部堆叠在一起来构建中间张量。和总结。默认情况下,它沿第一个轴堆叠 (dim=0):

def conv2D_multiple(X, K):
cross = []
result = 0
for x, k in zip(X, K):
cross.append(conv2D(x,k))

cross = torch.stack(cross)
result = cross.sum(dim=0)
return result

关于python - 如何添加数组列表(张量),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58602039/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com