gpt4 book ai didi

python - 在 PyTorch 中用矢量替换对角线元素

转载 作者:太空宇宙 更新时间:2023-11-04 00:20:05 25 4
gpt4 key购买 nike

我一直在到处寻找与 PyTorch 等价的东西,但找不到任何东西。

L_1 = np.tril(np.random.normal(scale=1., size=(D, D)), k=0)
L_1[np.diag_indices_from(L_1)] = np.exp(np.diagonal(L_1))

我想没有办法使用 Pytorch 以如此优雅的方式替换对角线元素。

最佳答案

有更简单的方法

dest_matrix[range(len(dest_matrix)), range(len(dest_matrix))] = source_vector

事实上,我们必须自己生成对角线索引。

使用示例:

dest_matrix = torch.randint(10, (3, 3))
source_vector = torch.randint(100, 200, (len(dest_matrix), ))
print('dest_matrix:\n', dest_matrix)
print('source_vector:\n', source_vector)

dest_matrix[range(len(dest_matrix)), range(len(dest_matrix))] = source_vector

print('result:\n', dest_matrix)

# dest_matrix:
# tensor([[3, 2, 5],
# [0, 3, 5],
# [3, 1, 1]])
# source_vector:
# tensor([182, 169, 147])
# result:
# tensor([[182, 2, 5],
# [ 0, 169, 5],
# [ 3, 1, 147]])

万一dest_matrix不是正方形你得拍min(dest_matrix.size())而不是 len(dest_matrix)range()

不如numpy优雅但这不需要存储新的索引矩阵。

是的,这会保留渐变

关于python - 在 PyTorch 中用矢量替换对角线元素,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49429147/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com