gpt4 book ai didi

python - 使用 numpy 和 jax 进行非传递子类化

转载 作者:行者123 更新时间:2023-12-05 02:48:05 27 4
gpt4 key购买 nike

我的问题很简单:

>>> isinstance(x, jax.numpy.ndarray)
True
>>> issubclass(jax.numpy.ndarray, numpy.ndarray)
True
>>> isinstance(x, numpy.ndarray)
False

?

现在我要扯远了,所以 SE 会接受我的合理问题。

最佳答案

出现这种情况的原因是因为 jax.numpy.ndarray 使用元类覆盖了实例检查:

class _ArrayMeta(type(np.ndarray)):  # type: ignore
"""Metaclass for overriding ndarray isinstance checks."""

def __instancecheck__(self, instance):
try:
return isinstance(instance.aval, _arraylike_types)
except AttributeError:
return isinstance(instance, _arraylike_types)

class ndarray(np.ndarray, metaclass=_ArrayMeta):
dtype: np.dtype
shape: Tuple[int, ...]
size: int

def __init__(shape, dtype=None, buffer=None, offset=0, strides=None,
order=None):
raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
" Use jax.numpy.array, or jax.numpy.zeros instead.")

( view source )

你的代码返回它所做的事情的原因是因为你有一个 x 值,它不是 numpy.ndarray 的实例,但是这个 __instancecheck__ 方法返回 true。

为什么在 JAX 中使用这种诡计?好吧,出于 JIT 编译、自动微分和其他转换的目的,JAX 使用称为 tracers 的替代对象,这些对象看起来和行为都像一个数组,尽管实际上并不是一个数组。这种对实例检查的覆盖是 JAX 用来使此类跟踪起作用的技巧之一。

关于python - 使用 numpy 和 jax 进行非传递子类化,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64654717/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com