gpt4 book ai didi

pytorch - PyTorch 中复杂掩码的最大池化

转载 作者:行者123 更新时间:2023-12-04 12:02:30 30 4
gpt4 key购买 nike

假设我有一个矩阵 src带形状(5, 3)和一个 bool 矩阵 adj带形状(5, 5)如下,

src = tensor([[ 0,  1,  2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11],
[12, 13, 14]])
adj = tensor([[1, 0, 1, 1, 0],
[0, 1, 1, 1, 0],
[1, 1, 0, 1, 1],
[1, 1, 1, 0, 0],
[0, 0, 1, 0, 1]])
我们可以取 src中的每一行作为一个节点嵌入,并把 adj中的每一行作为指示哪些节点是邻域的指标。
我的目标是在 src 中为每个节点的所有邻域节点嵌入中操作最大池化。 .
例如,作为第 0 个节点的邻域节点(包括它自己)是 0, 2, 3 ,因此我们在 [0, 1, 2] 上计算最大池化, [6, 7, 8] , [ 9, 10, 11]并领导更新嵌入 [ 9, 10, 11]更新 src_update 中的第 0 个节点.
我写的一个简单的解决方案是
src_update = torch.zeros_like(src)
for index in range(adj.size(0)):
list_of_non_zero = adj[index].nonzero().view(-1)
mat_non_zero = torch.index_select(src, 0, list_of_non_zero)
src_update[index] = torch.sum(mat_non_zero, dim=0)
src_update更新为:
tensor([[ 9, 10, 11],
[ 9, 10, 11],
[12, 13, 14],
[ 6, 7, 8],
[12, 13, 14]])
虽然能用,但是运行很慢,看起来不优雅!
任何改进它的建议 更高的效率 ?
另外,如果两者都 srcadj附加 批次 ( (batch, 5, 3) , (batch, 5, 5) ),如何让它工作?

最佳答案

我正在试验你的代码:

output = torch.zeros_like(src)
for index in range(adj.size(0)):
nz = adj[index].nonzero().view(-1)
output[index] = src.index_select(0, nz).max(0).values
瓶颈当然是 for 循环。首先想到的是使用某种分散功能。然而,这里的主要问题是邻居的数量可能因行而异。这意味着我们将无法在最大池化之前构建包含候选节点的张量。

一种可能的解决方案是创建一个类似于 src 的辅助张量其中第一个节点将包含占位符值(这些 不应 被最大池化选择,即我们可以使用 -inf )。我们可以使用包含索引的张量来索引这个张量:与您的方法相比,而不是使用 torch.nonzero() 删除零,我们将放置一个索引值 0(指的是 modified- src 中第一个位置的占位符行)。
在实践中,它是这样的:
对于辅助张量 src_ ,我放置了 -1 s 作为占位符值。
>>> src_ = torch.cat((-torch.ones_like(src[:1]), src))
tensor([[-inf, -inf, -inf],
[ 0., 1., 2.],
[ 3., 4., 5.],
[ 6., 7., 8.],
[ 9., 10., 11.],
[ 12., 13., 14.]])
我们可以转换 adj矩阵转换成索引张量:
>>> index = torch.arange(1, adj.size(1) + 1)*adj
tensor([[1, 0, 3, 4, 0],
[0, 2, 3, 4, 0],
[1, 2, 0, 4, 5],
[1, 2, 3, 0, 0],
[0, 0, 3, 0, 5]])
为了更容易索引,我们将扁平化 index , 索引 src_在第一个轴上,然后立即 reshape :
>>> indexed = src_[index.flatten(), :].reshape(*adj.shape, 3)
tensor([[[ 0., 1., 2.],
[-inf, -inf, -inf],
[ 6., 7., 8.],
[ 9., 10., 11.],
[-inf, -inf, -inf]],

...

[[-inf, -inf, -inf],
[-inf, -inf, -inf],
[ 6., 7., 8.],
[-inf, -inf, -inf],
[ 12., 13., 14.]]])
最后你可以最大池:
>>> indexed.max(dim=1).values
tensor([[ 9., 10., 11.],
[ 9., 10., 11.],
[12., 13., 14.],
[ 6., 7., 8.],
[12., 13., 14.]])

关于pytorch - PyTorch 中复杂掩码的最大池化,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64022697/

30 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com