
python - PyTorch equivalent of index_add_ that takes the maximum instead of the sum

Reprinted. Author: 太空宇宙. Last updated: 2023-11-04 04:41:08

In PyTorch, a tensor's index_add_ method sums values into the positions given by an index tensor:

import torch

idx = torch.tensor([0, 0, 0, 0, 1, 1])             # int64 index tensor
child = torch.tensor([1., 3., 5., 10., 8., 1.])
parent = torch.zeros(2)
parent.index_add_(0, idx, child)                   # parent[idx[j]] += child[j]

The first four child values are summed into parent[0] and the next two into parent[1], so the result is tensor([ 19., 9.]).
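The summing behaviour described above can be checked against an explicit Python loop. This is a minimal sketch on the example values from the question; `parent_loop` is a hypothetical name used only for the comparison.

```python
import torch

idx = torch.tensor([0, 0, 0, 0, 1, 1])
child = torch.tensor([1., 3., 5., 10., 8., 1.])

# Built-in: for every j, parent[idx[j]] += child[j]
parent = torch.zeros(2)
parent.index_add_(0, idx, child)

# The same accumulation written as an explicit (slow) Python loop
parent_loop = torch.zeros(2)
for j in range(len(idx)):
    parent_loop[int(idx[j])] += child[j]

print(parent.tolist())       # [19.0, 9.0]
print(parent_loop.tolist())  # [19.0, 9.0]
```

The missing index_max_ would follow the same pattern, with `+=` replaced by "keep the larger value".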

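However, I need an index_max_ instead, and no such method exists in the API. Is there an efficient way to do this (without looping or allocating more memory)? A (bad) loop-based solution is: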

for i in range(int(idx.max()) + 1):
    parent[i] = torch.max(child[idx == i])

This produces the expected result, tensor([ 10., 8.]), but it is very slow.

Best answer

A solution using indexing:

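import torch

def index_max(child, idx, num_partitions):
    # Build a num_partitions x num_samples boolean mask `idx_tiled`:
    # idx_tiled[i, j] is True if idx[j] == i, else False
    partition_idx = torch.arange(num_partitions, dtype=torch.long)
    partition_idx = partition_idx.view(-1, 1).expand(num_partitions, idx.shape[0])
    idx_tiled = idx.view(1, -1).repeat(num_partitions, 1)
    idx_tiled = idx_tiled == partition_idx

    # Fill entries outside each partition with -inf instead of multiplying by 0,
    # so partitions whose values are all negative still get the correct maximum
    # (an empty partition yields -inf)
    neg_inf = torch.tensor(float('-inf'), dtype=child.dtype)
    parent = torch.where(idx_tiled, child, neg_inf)
    parent, _ = torch.max(parent, dim=1)
    return parent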
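To make the mask idea concrete, here is the same technique run directly on the question's example. It is a sketch that builds the mask by broadcasting rather than with `repeat`/`expand` (an assumption of equivalence, not the answer's exact code), and it fills non-members with `-inf` before taking the row-wise maximum.

```python
import torch

idx = torch.tensor([0, 0, 0, 0, 1, 1])
child = torch.tensor([1., 3., 5., 10., 8., 1.])
num_partitions = 2

# mask[i, j] is True when child[j] belongs to partition i:
#   partition 0 -> [T, T, T, T, F, F]
#   partition 1 -> [F, F, F, F, T, T]
mask = idx.view(1, -1) == torch.arange(num_partitions).view(-1, 1)

# Replace non-members with -inf, then reduce each row with max
masked = torch.where(mask, child, torch.tensor(float('-inf')))
parent = masked.max(dim=1).values
print(parent.tolist())  # [10.0, 8.0]
```

Note that `mask` has `num_partitions * len(idx)` entries, so this trades the Python loop for extra memory.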

Benchmark:

import timeit

setup = '''
import torch

def index_max_v0(child, idx, num_partitions):
    parent = torch.zeros(num_partitions)
    for i in range(max(idx) + 1):
        parent[i] = torch.max(child[idx == i])
    return parent

def index_max(child, idx, num_partitions):

    # Build a num_partitions x num_samples matrix `idx_tiled` whose
    # row i holds 1 where idx[j] == i and 0 elsewhere
    partition_idx = torch.range(0, num_partitions - 1, dtype=torch.long)
    partition_idx = partition_idx.view(-1, 1).expand(num_partitions, idx.shape[0])
    idx_tiled = idx.view(1, -1).repeat(num_partitions, 1)
    idx_tiled = (idx_tiled == partition_idx).float()

    parent = idx_tiled * child
    parent, _ = torch.max(parent, dim=1)
    return parent

idx = torch.LongTensor([0,0,0,0,1,1])
child = torch.FloatTensor([1, 3, 5, 10, 8, 1])
num_partitions = torch.unique(idx).shape[0]

'''
print(min(timeit.Timer('index_max_v0(child, idx, num_partitions)', setup=setup).repeat(5, 1000)))
# > 0.05308796599274501
print(min(timeit.Timer('index_max(child, idx, num_partitions)', setup=setup).repeat(5, 1000)))
# > 0.024736385996220633
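The masking approach still allocates a `num_partitions x num_samples` matrix, which the question hoped to avoid. Newer PyTorch releases (roughly 1.12 and later; check the documentation for your version) provide `Tensor.scatter_reduce_`, which accepts `reduce="amax"` and behaves like the missing index_max_. This is a sketch on the question's data; it has not been timed against the benchmark above.

```python
import torch

idx = torch.tensor([0, 0, 0, 0, 1, 1])
child = torch.tensor([1., 3., 5., 10., 8., 1.])
num_partitions = 2

# Requires a PyTorch version whose scatter_reduce_ takes `reduce` and `include_self`.
# include_self=False ignores the initial zeros for every slot that receives at
# least one value, so partitions with only negative values are handled correctly.
parent = torch.zeros(num_partitions)
parent.scatter_reduce_(0, idx, child, reduce="amax", include_self=False)
print(parent.tolist())  # [10.0, 8.0]
```

No intermediate mask is built, so the extra memory is limited to the output tensor itself.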

Regarding "python - PyTorch equivalent of index_add_ that takes the maximum instead of the sum", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/50605205/

Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号