gpt4 book ai didi

python - 沿给定轴将 numpy ndarray 与一维数组相乘

转载 作者:太空狗 更新时间:2023-10-29 21:28:46 27 4
gpt4 key购买 nike

看来我迷失在一些可能很愚蠢的事情中。 我有一个 n 维 numpy 数组,我想将它与沿某个维度(可以改变!)的向量(一维数组)相乘。例如,假设我想沿第一个数组的轴 0 将一个二维数组乘以一个一维数组,我可以这样做:

a=np.arange(20).reshape((5,4))
b=np.ones(5)
c=a*b[:,np.newaxis]

很简单,但我想将这个想法扩展到 n 维(对于 a,而 b 始终是 1d)和任何轴。换句话说,我想知道如何在正确的位置生成带有 np.newaxis 的切片。假设 a 是 3d,我想沿 axis=1 相乘,我想生成正确给出的切片:

c=a*b[np.newaxis,:,np.newaxis]

即给定 a 的维数(比如 3)和我要相乘的轴(比如 axis=1),我如何生成和传递切片:

np.newaxis,:,np.newaxis

谢谢。

最佳答案

解决方案代码-

import numpy as np

# Given axis along which elementwise multiplication with broadcasting
# is to be performed
given_axis = 1

# Create an array which would be used to reshape 1D array, b to have
# singleton dimensions except for the given axis where we would put -1
# signifying to use the entire length of elements along that axis
dim_array = np.ones((1,a.ndim),int).ravel()
dim_array[given_axis] = -1

# Reshape b with dim_array and perform elementwise multiplication with
# broadcasting along the singleton dimensions for the final output
b_reshaped = b.reshape(dim_array)
mult_out = a*b_reshaped

步骤演示的示例运行 -

In [149]: import numpy as np

In [150]: a = np.random.randint(0,9,(4,2,3))

In [151]: b = np.random.randint(0,9,(2,1)).ravel()

In [152]: whos
Variable Type Data/Info
-------------------------------
a ndarray 4x2x3: 24 elems, type `int32`, 96 bytes
b ndarray 2: 2 elems, type `int32`, 8 bytes

In [153]: given_axis = 1

现在,我们想沿given axis = 1 执行逐元素乘法。让我们创建 dim_array:

In [154]: dim_array = np.ones((1,a.ndim),int).ravel()
...: dim_array[given_axis] = -1
...:

In [155]: dim_array
Out[155]: array([ 1, -1, 1])

最后, reshape b 并执行逐元素乘法:

In [156]: b_reshaped = b.reshape(dim_array)
...: mult_out = a*b_reshaped
...:

再次检查whos 信息并特别注意b_reshapedmult_out:

In [157]: whos
Variable Type Data/Info
---------------------------------
a ndarray 4x2x3: 24 elems, type `int32`, 96 bytes
b ndarray 2: 2 elems, type `int32`, 8 bytes
b_reshaped ndarray 1x2x1: 2 elems, type `int32`, 8 bytes
dim_array ndarray 3: 3 elems, type `int32`, 12 bytes
given_axis int 1
mult_out ndarray 4x2x3: 24 elems, type `int32`, 96 bytes

关于python - 沿给定轴将 numpy ndarray 与一维数组相乘,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/30031828/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com