gpt4 book ai didi

python - 如何从 Pytorch 张量中删除每个填充为零的列?

转载 作者:行者123 更新时间:2023-12-04 10:57:42 24 4
gpt4 key购买 nike

我有一个 pytorch 张量 A像下面这样:

A = 
tensor([[ 4, 3, 3, ..., 0, 0, 0],
[ 13, 4, 13, ..., 0, 0, 0],
[707, 707, 4, ..., 0, 0, 0],
...,
[ 7, 7, 7, ..., 0, 0, 0],
[ 0, 0, 0, ..., 0, 0, 0],
[195, 195, 195, ..., 0, 0, 0]], dtype=torch.int32)

我想要:
  • 识别所有条目都等于 0 的所有列
  • 仅删除所有条目都等于 0 的列

  • 我可以想象这样做:
    zero_list = []
    for j in range(A.size()[1]):
    if torch.sum(A[:,j]) == 0:
    zero_list = zero_list.append(j)

    标识其元素只有 0 的列
    但我不确定如何从原始张量中删除这些填充为 0 的列。

    如何根据索引号从 pytorch 张量中删除零列?

    谢谢,

    最佳答案

    索引要保留的列而不是要删除的列更有意义。

    valid_cols = []
    for col_idx in range(A.size(1)):
    if not torch.all(A[:, col_idx] == 0):
    valid_cols.append(col_idx)
    A = A[:, valid_cols]

    或者更神秘一点
    valid_cols = [col_idx for col_idx, col in enumerate(torch.split(A, 1, dim=1)) if not torch.all(col == 0)]
    A = A[:, valid_cols]

    关于python - 如何从 Pytorch 张量中删除每个填充为零的列?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59076228/

    24 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com