gpt4 book ai didi

python - 在 Pytorch 中计算欧几里德范数.. 麻烦理解和实现

转载 作者:太空狗 更新时间:2023-10-30 02:24:43 24 4
gpt4 key购买 nike

我看到另一个 StackOverflow 线程讨论了计算欧几里得范数的各种实现,但我无法理解特定实现的工作原理/方式。

代码可在 MMD 指标的实现中找到:https://github.com/josipd/torch-two-sample/blob/master/torch_two_sample/statistics_diff.py

这是一些开始的样板:

import torch
sample_1, sample_2 = torch.ones((10,2)), torch.zeros((10,2))

然后下一部分是我们从上面的代码中提取的内容..我不确定为什么样本被连接在一起..

sample_12 = torch.cat((sample_1, sample_2), 0)
distances = pdist(sample_12, sample_12, norm=2)

然后传递给 pdist 函数:

def pdist(sample_1, sample_2, norm=2, eps=1e-5):
r"""Compute the matrix of all squared pairwise distances.
Arguments
---------
sample_1 : torch.Tensor or Variable
The first sample, should be of shape ``(n_1, d)``.
sample_2 : torch.Tensor or Variable
The second sample, should be of shape ``(n_2, d)``.
norm : float
The l_p norm to be used.
Returns
-------
torch.Tensor or Variable
Matrix of shape (n_1, n_2). The [i, j]-th entry is equal to
``|| sample_1[i, :] - sample_2[j, :] ||_p``."""

这里我们进入计算的核心

    n_1, n_2 = sample_1.size(0), sample_2.size(0)
norm = float(norm)
if norm == 2.:
norms_1 = torch.sum(sample_1**2, dim=1, keepdim=True)
norms_2 = torch.sum(sample_2**2, dim=1, keepdim=True)
norms = (norms_1.expand(n_1, n_2) +
norms_2.transpose(0, 1).expand(n_1, n_2))
distances_squared = norms - 2 * sample_1.mm(sample_2.t())
return torch.sqrt(eps + torch.abs(distances_squared))

我不明白为什么要这样计算欧几里德范数。任何见解将不胜感激

最佳答案

让我们逐步浏览一下这段代码。欧氏距离的定义,即L2范数是

enter image description here

让我们考虑最简单的情况。我们有两个样本,

enter image description here

示例 a 有两个向量 [a00, a01][a10, a11]。与示例 b 相同。让我们首先计算范数

n1, n2 = a.size(0), b.size(0)  # here both n1 and n2 have the value 2
norm1 = torch.sum(a**2, dim=1)
norm2 = torch.sum(b**2, dim=1)

现在我们得到

enter image description here

接下来,我们有 norms_1.expand(n_1, n_2)norms_2.transpose(0, 1).expand(n_1, n_2)

enter image description here

请注意,b 是转置的。两者之和给出 norm

enter image description here

sample_1.mm(sample_2.t()),就是两个矩阵相乘。

enter image description here

因此,手术后

distances_squared = norms - 2 * sample_1.mm(sample_2.t())

你得到

enter image description here

最后,最后一步是对矩阵中的每个元素求平方根。

关于python - 在 Pytorch 中计算欧几里德范数.. 麻烦理解和实现,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51986758/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com