gpt4 book ai didi

python - 从 block 对角线 PyTorch 张量中提取 block

转载 作者:行者123 更新时间:2023-12-03 23:03:22 28 4
gpt4 key购买 nike

我有一个形状为 (m*n, m*n) 的张量,我想提取一个大小为 (n, m*n) 的张量,其中包含对角线上大小为 n*n 的 m 个块。例如:

>>> a
tensor([[1, 2, 0, 0],
[3, 4, 0, 0],
[0, 0, 5, 6],
[0, 0, 7, 8]])
我想要一个函数 extract(a, m, n)这将输出:
>>> extract(a, 2, 2)
tensor([[1, 2, 5, 6],
[3, 4, 7, 8]])
我想过使用某种切片,因为块可以表示为:
>>> for i in range(m):
... print(a[i*m: i*m + n, i*m: i*m + n])
tensor([[1, 2],
[3, 4]])
tensor([[5, 6],
[7, 8]])

最佳答案

您可以利用 reshape和切片:

import torch
import numpy as np

def extract(a, m, n):
s=(range(m), np.s_[:], range(m), np.s_[:]) # the slices of the blocks
a.reshape(m, n, m, n)[s] # reshaping according to blocks and slicing
return a.reshape(m*n, n).T # reshape to desired output format
例子:
a = torch.arange(36).reshape(6,6)
extract(a, 3, 2)

tensor([[ 0, 6, 14, 20, 28, 34],
[ 1, 7, 15, 21, 29, 35]])

extract(a, 2, 3)

tensor([[ 0, 6, 12, 21, 27, 33],
[ 1, 7, 13, 22, 28, 34],
[ 2, 8, 14, 23, 29, 35]])

关于python - 从 block 对角线 PyTorch 张量中提取 block ,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64195225/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com