gpt4 book ai didi

Pytorch 以交替顺序连接行

转载 作者:行者123 更新时间:2023-12-05 02:08:22 25 4
gpt4 key购买 nike

我正在尝试对变形金刚论文中的位置编码进行编码。为此,我需要执行类似于以下的操作:

a = torch.arange(20).reshape(4,5) 
b = a * 2
c = torch.cat([torch.stack([a_row,b_row]) for a_row, b_row in zip(a,b)])

我觉得可能有更快的方法来完成上述操作?也许通过在 a 和 b 上添加一个维度?

最佳答案

为此我会简单地使用赋值运算符:

c = torch.zeros(8, 5)
c[::2, :] = a # Index every second row, starting from 0
c[1::2, :] = b # Index every second row, starting from 1

在对这两个解决方案进行计时时,我使用了以下内容:

import timeit
import torch
a = torch.arange(20).reshape(4,5)
b = a * 2

suggested = timeit.timeit("c = torch.cat([torch.stack([a_row, b_row]) for a_row, b_row in zip (a, b)])",
setup="import torch; from __main__ import a, b", number=10000)
print(suggested/10000)
# 4.5105120493099096e-05

improved = timeit.timeit("c = torch.zeros(8, 5); c[::2, :] = a; c[1::2, :] = b",
setup="import torch; from __main__ import a, b", number=10000)
print(improved/10000)
# 2.1489459509029985e-05

第二种方法花费的时间始终较少(大约一半),即使单次迭代仍然非常快。当然,您必须针对实际张量大小对此进行测试,但这是我能想到的最直接的解决方案。迫不及待地想看看是否有人对此有一些更快的漂亮的低级解决方案!

此外,请记住,我没有为 b 的创建计时,假设您想要交织的张量已经给出。

关于Pytorch 以交替顺序连接行,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61026393/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com