gpt4 book ai didi

python - 如何防止 Numpy 拆分类似数组的对象

转载 作者:行者123 更新时间:2023-11-28 17:46:34 25 4
gpt4 key购买 nike

如果我考虑以下简单类:

class Quantity(object):

def __init__(self, value, unit):
self.unit = unit
self.value = value

def __getitem__(self, key):
return Quantity(self.value[key], unit=self.unit)

def __len__(self):
return len(self.value)

并创建一个实例:

import numpy as np
q = Quantity(np.array([1,2,3]), 'degree')
print(repr(np.array(q)))

然后,如果我将这个对象传递给 Numpy,它会将对象拆分为 3 个 Quantity 实例的对象数组:

array([<__main__.Quantity object at 0x1073a0d50>,
<__main__.Quantity object at 0x1073a0d90>,
<__main__.Quantity object at 0x1073a0dd0>], dtype=object)

这是由于 __len____getitem__ 方法的存在 - 如果我删除其中任何一个,则对象不会被拆分:

array(<__main__.Quantity object at 0x110a4e610>, dtype=object)

我仍然想保留 __len____getitem__,但是有没有办法防止 Numpy 拆分对象?

编辑:除了将 Quantity 设为 ndarray 子类之外,我对其他的解决方案很感兴趣

最佳答案

这是您要找的吗?

class Quantity(object):

def __init__(self, value, unit):
self.unit = unit
self.value = value

def __getitem__(self, key):
return Quantity(self.value[key], unit=self.unit)

def __len__(self):
return len(self.value)

def __array__(self):
return self.value

np.array 使用__array__ 方法

In [11]: q
Out[11]: <__main__.Quantity at 0x1042bdf90>

In [12]: np.array(q)
Out[12]: array([ 1., 2., 3.])

In [13]: print(repr(np.array(q)))
array([ 1., 2., 3.])

In [14]: len(q)
Out[14]: 3

In [15]: q[1]
Out[15]: <__main__.Quantity at 0x1042bdd50>

In [16]: q[0]
Out[16]: <__main__.Quantity at 0x1042bdd90>

In [17]: q[0].value
Out[17]: 1.0

关于python - 如何防止 Numpy 拆分类似数组的对象,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/17444657/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com