gpt4 book ai didi

pytorch - torch.masked_scatter 结果不符合预期

转载 作者:行者123 更新时间:2023-12-05 04:46:57 37 4
gpt4 key购买 nike

我的 torch 代码:

import torch
x = torch.tensor([[0.3992, 0.2908, 0.9004, 0.4850, 0.6004],
[0.5735, 0.9006, 0.6797, 0.4152, 0.1732]])
print(x.shape)
mask = torch.tensor([[False, False, True, False, True],
[ True, True, True, False, False]])
print(mask.shape)
y = torch.tensor([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]])
print(y.shape)
y.masked_scatter_(mask, x)
print(y)

结果是:

torch.Size([2, 5])
torch.Size([2, 5])
torch.Size([2, 5])
tensor([[0.0000, 0.0000, 0.3992, 0.0000, 0.2908],
[0.9004, 0.4850, 0.6004, 0.0000, 0.0000]])

我认为结果答案是:

tensor([[0.0000, 0.0000, 0.9004, 0.0000, 0.6004],
[0.5375, 0.9006, 0.6797, 0.0000, 0.0000]])

我的pytorch版本是pytorch1.4

最佳答案

你是对的,这很困惑,而且几乎没有文档。

然而,散点图的工作方式(正如您所发现的)是连续第 iTrue 被赋予第 i 个来自源头的值(value)。所以不是对应于 True 位置的值。

幸运的是,您可以使用普通的索引符号轻松实现您正在尝试做的事情:

>>> y[mask] = x[mask]
>>> y
tensor([[0.0000, 0.0000, 0.9004, 0.0000, 0.6004],
[0.5735, 0.9006, 0.6797, 0.0000, 0.0000]])

关于pytorch - torch.masked_scatter 结果不符合预期,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/68675160/

37 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com