gpt4 book ai didi

python - numpy ndarray 的子类不能按预期工作

转载 作者:太空狗 更新时间:2023-10-29 23:59:03 25 4
gpt4 key购买 nike

`大家好。

我发现在对 ndarray 进行子类化时有一个奇怪的行为。

import numpy as np

class fooarray(np.ndarray):
def __new__(cls, input_array, *args, **kwargs):
obj = np.asarray(input_array).view(cls)
return obj

def __init__(self, *args, **kwargs):
return

def __array_finalize__(self, obj):
return

a=fooarray(np.random.randn(3,5))
b=np.random.randn(3,5)

a_sum=np.sum(a,axis=0,keepdims=True)
b_sum=np.sum(b,axis=0, keepdims=True)

print a_sum.ndim #1
print b_sum.ndim #2

如您所见,keepdims 参数对我的子类 fooarray 不起作用。它失去了它的一根轴。我怎么不能避免这个问题?或者更一般地说,我怎样才能正确地子类化 numpy ndarray?

最佳答案

np.sum 可以接受各种对象作为输入:例如,不仅是 ndarrays,还有列表、生成器、np.matrixkeepdims 参数显然对列表或生成器没有意义。它也不适合 np.matrix 实例,因为 np.matrix 总是有 2 个维度。如果您查看 np.matrix.sum 的调用签名,您会发现它的 sum 方法没有 keepdims 参数:

Definition: np.matrix.sum(self, axis=None, dtype=None, out=None)

所以 ndarray 的一些子类可能有 sum 方法,但没有 keepdims 参数。不幸的是,这违反了 Liskov substitution principle以及你遇到的陷阱的起源。

现在,如果您查看 the source code for np.sum ,您会看到它是一个委托(delegate)函数,它试图根据第一个参数的类型确定要做什么。

如果第一个参数的类型不是ndarray,它会丢弃keepdims参数。这样做是因为将 keepdims 参数传递给 np.matrix.sum 会引发异常。

因为 np.sum 试图以最通用的方式进行委托(delegate),而不是对 ndarray 的子类可能采用的参数做任何假设,它删除了 keepdims 参数时传递一个 fooarray

解决方法是不使用 np.sum,而是调用 a.sum。无论如何,这更直接,因为 np.sum 只是一个委托(delegate)函数。

import numpy as np


class fooarray(np.ndarray):
def __new__(cls, input_array, *args, **kwargs):
obj = np.asarray(input_array, *args, **kwargs).view(cls)
return obj

a = fooarray(np.random.randn(3, 5))
b = np.random.randn(3, 5)

a_sum = a.sum(axis=0, keepdims=True)
b_sum = np.sum(b, axis=0, keepdims=True)

print(a_sum.ndim) # 2
print(b_sum.ndim) # 2

关于python - numpy ndarray 的子类不能按预期工作,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/23735015/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com