gpt4 book ai didi

python - 是否可以使用 numpy 压缩除 N 维以外的所有维度?

转载 作者:太空宇宙 更新时间:2023-11-03 23:56:36 26 4
gpt4 key购买 nike

我想知道是否有一种方法可以将大小为 1 的所有维度压缩到一个数组中,并且不压缩 N 个维度(即使这些维度的大小为 1)。

为什么?假设我有一个接收一个数组的函数,它返回数组及其转置的矩阵乘积,但数组的形状是未知的(最大 2 个 dims,大小 > 1,但可以有更多大小为 1 的 dims)

可能的矩阵形状示例:

A.shape -> (M,N)
B.shape -> (M,N,1[...,1])
C.shape -> (M,1[...,1])

我希望始终具有 A 的形状 (ndim = 2) 以便执行矩阵乘积。

我可以使用 np.squeeze(X),仅此而已,但在 C 的情况下,这会导致以下问题:

import numpy as np

def my_function(arr):
arr = np.squeeze(arr)
return np.dot(arr, arr.transpose())

x = np.arange(1, 6) # shape (5,)
x = x.reshape((x.size, 1, 1)) # shape (5, 1, 1)
y = my_function(x)
print(y)
# Actual y.shape -> () [is a number]
# Expected y.shape -> (5, 5) [matrix]

我希望 np.squeeze() 函数有一个参数 axis_to_keep。你知道是否有办法轻松实现这一目标?我知道一些方法,但我需要最有效的方法,因为我必须多次执行这些操作。

最佳答案

使用 axes_to_keep 参数进行挤压

这是一个用于通用 n-dim 数组的请求 axes_to_keep 参数,可将这些轴保持在原位 -

def squeeze_generic(a, axes_to_keep):
out_s = [s for i,s in enumerate(a.shape) if i in axes_to_keep or s!=1]
return a.reshape(out_s)

样本运行-

In [105]: a = np.random.rand(3,4,5,1,1,6,1)

In [106]: squeeze_generic(a, axes_to_keep=(3,4)).shape
Out[106]: (3, 4, 5, 1, 1, 6)

In [107]: squeeze_generic(a, axes_to_keep=(3,4,6)).shape
Out[107]: (3, 4, 5, 1, 1, 6, 1)

# For cases when axes_to_keep lists axes that aren't singleton
In [108]: squeeze_generic(a, axes_to_keep=(0,1)).shape
Out[108]: (3, 4, 5, 6)

解决您的问题以保留前两个轴

因此,要解决您保留前两个轴的特定情况,它将是 -

squeeze_generic(a, axes_to_keep=range(2))

让我们看一下示例案例 -

In [55]: a = np.random.rand(3,5)

In [56]: squeeze_generic(a, axes_to_keep=range(2)).shape
Out[56]: (3, 5)

In [57]: a = np.random.rand(3,5,1)

In [58]: squeeze_generic(a, axes_to_keep=range(2)).shape
Out[58]: (3, 5)

In [59]: a = np.random.rand(3,1)

In [60]: squeeze_generic(a, axes_to_keep=range(2)).shape
Out[60]: (3, 1)

如果保证第二个轴之后的所有轴都是单轴(长度轴=1)(如果有的话),那么一个简单的 reshape 也可以完成这项工作-

a.reshape(a.shape[0],-1)

关于python - 是否可以使用 numpy 压缩除 N 维以外的所有维度?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57472104/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com