gpt4 book ai didi

python - 按一列中的特定值对行进行分组,并在 PyTorch 中计算平均值

转载 作者:行者123 更新时间:2023-12-04 09:38:49 31 4
gpt4 key购买 nike

样本张量:

tensor([[ 0.,  1.,  2.,  3.,  4.,  5.],  # class1
[ 6., 7., 8., 9., 10., 11.], # class3
[12., 13., 14., 15., 16., 17.], # class2
[18., 19., 20., 21., 22., 23.], # class0
[24., 25., 26., 27., 28., 29.]. # class1
])

预期结果:
tensor([[18., 19., 20., 21., 22., 23.], # class0
[12., 13., 14., 15., 16., 17.], # class1
[12., 13., 14., 15., 16., 17.], # class2
[ 6., 7., 8., 9., 10., 11.]. # class3
])

是否有纯 PyTorch 方法来实现这一点?

最佳答案

您可以使用 index_add 根据类索引添加然后除以每个标签的数量,使用 unique 计算:

# inputs
x = torch.arange(30.).view(5,6) # sample tensor
c = c = torch.tensor([1, 3, 2, 0, 1], dtype=torch.long) # class indices

# allocate space for output
result = torch.zeros((c.max() + 1, x.shape[1]), dtype=x.dtype)
# use index_add_ to sum up rows according to class
result.index_add_(0, c, x)
# use "unique" to count how many of each class
_, counts = torch.unique(c, return_counts=True)
# divide the sum by the counts to get the average
result /= counts[:, None]
result正如预期的那样:

Out[*]:
tensor([[18., 19., 20., 21., 22., 23.],
[12., 13., 14., 15., 16., 17.],
[12., 13., 14., 15., 16., 17.],
[ 6., 7., 8., 9., 10., 11.]])

关于python - 按一列中的特定值对行进行分组,并在 PyTorch 中计算平均值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62424100/

31 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com