gpt4 book ai didi

python - 从一个 100x100 的 pytorch 张量获得一个 10x10 的补丁,并在边界周围环绕圆环样式

转载 作者:行者123 更新时间:2023-12-03 23:00:39 24 4
gpt4 key购买 nike

如何从 100x100 pytorch 张量获得 10x10 的补丁,并附加约束,即如果补丁超出阵列边界,则它环绕边缘(好像阵列是环面,顶部连接到底部,左侧连接到右侧)?
我写了这段代码来完成这项工作,我正在寻找更优雅、更高效和更清晰的东西:

def shift_matrix(a, distances) -> Tensor:
x, y = distances
a = torch.cat((a[x:], a[0:x]), dim=0)
a = torch.cat((a[:, y:], a[:, :y]), dim=1)
return a

def randomly_shift_matrix(a) -> Tensor:
return shift_matrix(a, np.random.randint(low = 0, high = a.size()))

def random_patch(a, size) -> Tensor:
full_shifted_matrix = randomly_shift_matrix(a)
return full_shifted_matrix[0:size[0], 0:size[1]]
我觉得带有负索引切片的东西应该可以工作。不过我还没找到。
您可以 see the code in google colab here .

最佳答案

您正在寻找 torch.roll

def random_patch(a, size) -> Tensor:
shifts = np.random.randint(low = 0, high = a.size())
return torch.roll(a, shifts=shifts, dims=(0, 1))[:size[0], :size[1]]

关于python - 从一个 100x100 的 pytorch 张量获得一个 10x10 的补丁,并在边界周围环绕圆环样式,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66006277/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com