gpt4 book ai didi

deep-learning - 如何有效地计算 PyTorch 中的批量成对距离

转载 作者:行者123 更新时间:2023-12-03 17:20:17 25 4
gpt4 key购买 nike

我有形状为 X 的张量 BxNxD和 Y 形状 BxNxD .

我想计算批次中每个元素的成对距离,即 I a BxMxN张量。

我该怎么做呢?

这里有一些关于这个话题的讨论:https://github.com/pytorch/pytorch/issues/9406 ,但我不明白,因为有很多实现细节,而没有突出显示实际解决方案。

一种天真的方法是使用此处讨论的非成对成对距离的答案:https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065 , IE。

import torch
import numpy as np

B = 32
N = 128
M = 256
D = 3

X = torch.from_numpy(np.random.normal(size=(B, N, D)))
Y = torch.from_numpy(np.random.normal(size=(B, M, D)))


def pairwise_distances(x, y=None):
x_norm = (x**2).sum(1).view(-1, 1)
if y is not None:
y_t = torch.transpose(y, 0, 1)
y_norm = (y**2).sum(1).view(1, -1)
else:
y_t = torch.transpose(x, 0, 1)
y_norm = x_norm.view(1, -1)

dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
return torch.clamp(dist, 0.0, np.inf)


out = []
for b in range(B):
out.append(pairwise_distances(X[b], Y[b]))
print(torch.stack(out).shape)

如何在不循环 B 的情况下执行此操作?
谢谢

最佳答案

我有一个类似的问题,并花了一些时间找到最简单和最快的解决方案。现在您可以使用 PyTorch cdist 计算批量距离这会给你 BxMxN张量:

torch.cdist(Y, X)

此外,如果您只想计算两个矩阵的每对行之间的距离,它也能很好地工作。

关于deep-learning - 如何有效地计算 PyTorch 中的批量成对距离,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55126072/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com