gpt4 book ai didi

jit - 是否可以 jit 使用 jax.numpy.unique 的函数?

转载 作者:行者123 更新时间:2023-12-05 06:00:17 28 4
gpt4 key购买 nike

以下代码无效:

def get_unique(arr):
return jnp.unique(arr)

get_unique = jit(get_unique)
get_unique(jnp.ones((10,)))

有关jnp.unique的使用的错误信息:

FilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(float32[10])>with<DynamicJaxprTrace(level=0/1)>
The error arose in jnp.unique()

documentation on sharp bits说明如果内部数组的形状取决于参数值,则 jit 不起作用。这正是这里的情况。

根据文档,一个潜在的解决方法是指定静态参数。但这不适用于我的情况。几乎每个函数调用的参数都会改变。我已将我的代码拆分为一个预处理步骤,该步骤执行诸如此 jnp.unique 的计算,以及一个可以 jitted 的计算步骤。

但我还是想问一下,是否有一些我不知道的解决方法?

最佳答案

不,由于您提到的原因,目前无法在非静态值上使用 jnp.unique

在类似的情况下,JAX 有时会添加额外的参数,这些参数可用于指定输出的静态大小(例如,jax.numpy.nonzero 中的 size 参数)但目前没有实现类似的功能对于 jnp.unique。如果那是您想要的,那么值得提交 feature request .

关于jit - 是否可以 jit 使用 jax.numpy.unique 的函数?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67739742/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com