gpt4 book ai didi

python - 通过掩码选择 numpy 数组的元素并保留原始尺寸

转载 作者:行者123 更新时间:2023-12-05 06:06:21 24 4
gpt4 key购买 nike

您好,我有以下数据

ids = np.concatenate([1.0 * np.ones(shape=(4, 9,)), 
2.0 * np.ones(shape=(4, 3,))], axis=1)

logits = np.random.normal(size=(4, 9 + 3, 256))

现在我只想获取 id 为 1.0 的 numpy 数组,并且我想获取大小为 (4,9, 256)

的数组

我尝试了 logits[ids == 1.0, :] 但我得到了 (36, 256)如何在不连接前两个维度的情况下进行切片?

当前尺寸仅为示例尺寸,我正在寻找通用解决方案。

最佳答案

您的问题似乎假设每一行都有相同数量的非零条目;在这种情况下,您通常可以像这样解决问题:

mask = (ids == 1)
num_per_row = mask.sum(1)

# same number of entries per row is required
assert np.all(num_per_row == num_per_row[0])

result = logits[mask].reshape(logits.shape[0], num_per_row[0], logits.shape[2])

print(result.shape)
# (4, 9, 256)

关于python - 通过掩码选择 numpy 数组的元素并保留原始尺寸,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65762614/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com