gpt4 book ai didi

python-3.x - 场感知分解的向量化实现

转载 作者:行者123 更新时间:2023-11-30 08:42:15 25 4
gpt4 key购买 nike

我想以矢量化的方式实现场感知分解模型(FFM)。在FFM中,通过以下方程进行预测

$ \sum_{j_1=1}^n \sum_{j_2=j_1+1}^n (\textbf{w}_{j_1, f_2}, \textbf{w}_{j_2, f_1}) x_{j_1} x_{j_2} $

其中 w 是依赖于该特征和另一个特征的字段的嵌入。有关详细信息,请参阅 FFM 中的方程 (4) .

为此,我定义了以下参数:

import torch

W = torch.nn.Parameter(torch.Tensor(n_features, n_fields, n_factors), requires_grad=True)

现在,给定大小为 (batch_size, n_features) 的输入 x,我希望能够计算前面的方程。这是我当前的(非矢量化)实现:

total_inter = torch.zeros(x.shape[0])
for i in range(n_features):
for j in range(i + 1, n_features):
temp1 = torch.mm(
x[:, i].unsqueeze(1),
W[i, feature2field[j], :].unsqueeze(0))
temp2 = torch.mm(
x[:, j].unsqueeze(1),
W[j, feature2field[i], :].unsqueeze(0))
total_inter += torch.sum(temp1 * temp2, dim=1)

不出所料,这个实现速度非常慢,因为 n_features 很容易达到 1000!但请注意,x 的大部分条目均为 0。所有输入均表示赞赏!

编辑:

如果它能以任何方式提供帮助,以下是该模型在 PyTorch 中的一些实现:

不幸的是,我无法弄清楚他们到底是如何做到的。

额外更新:

我现在可以通过执行以下操作以更有效的方式获取 xW 的乘积:

temp = torch.einsum('ij, jkl -> ijkl', x, W)

因此,我的循环现在是:

total_inter = torch.zeros(x.shape[0])
for i in range(n_features):
for j in range(i + 1, n_features):
temp1 = temp[:, i, feature2field[j], :]
temp2 = temp[:, j, feature2field[i], :]
total_inter += 0.5 * torch.sum(temp1 * temp2, dim=1)

然而,由于该循环进行了大约 500 000 次迭代,这仍然太长了。

最佳答案

  1. 使用 pytorch sparse tensors 可能会帮助您加快乘法速度。 .

  2. 另外,可能有效的方法如下:创建 n 个数组,每个特征 i 一个数组,该数组将在每一行中保存其相应的场因子。例如对于特征 i = 0

[ W[0, feature2field[0], :],
W[0, feature2field[1], :],
W[0, feature2field[n], :]]

然后计算这些数组的乘法,我们称它们为 F,与 X

R[i] = F[i] * X

因此,R 中的每个元素都将保存 F[i] 与 X 相乘的结果(一个数组)。

接下来,您将每个 R[i] 与其转置相乘

R[i] = R[i] * R[i].T

现在您可以像以前一样在循环中进行求和

for i in range(n_features):
total_inter += torch.sum(R[i], dim=1)

请对此持保留态度,因为我还没有测试过。无论如何,我认为它会给你指明正确的方向。

可能出现的一个问题是转置乘法,其中每个元素也将与其自身相乘,然后添加到总和中。我不认为这会影响分类器,但无论如何你都可以使转置对角线中的元素和0以上(包括对角线)。

此外,虽然很小,但请将第一个取消挤压操作移到嵌套 for 循环之外。

希望对您有所帮助。

关于python-3.x - 场感知分解的向量化实现,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56860236/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com