gpt4 book ai didi

python - 处理 jax numpy 数组中的不同形状(jit 兼容)

转载 作者:行者123 更新时间:2023-12-05 02:31:03 25 4
gpt4 key购买 nike

重要说明:我需要这里的所有内容都与 jit 兼容,否则我的问题就微不足道了:)

我有一个 jax numpy 数组,例如:

a = jnp.array([1,5,3,4,5,6,7,2,9])

首先,我根据一个值对其进行过滤,假设我只保留 < 5

的值
a = jnp.where((a < 5), x=a, y=jnp.nan)
# a is now [ 1. nan 3. 4. nan nan nan 2. nan]

我只想保留非 nan 值:[ 1. 3. 4. 2.] 然后我将使用此数组进行其他操作。

但更重要的是,在我的程序执行期间,这段代码将执行多次,阈值会发生变化(即它不会总是 5)。

因此,最终数组的形状也会发生变化。这是我的 jit 编译问题,我不知道如何使其与 jit 兼容,因为形状取决于有多少元素符合阈值条件。

最佳答案

JAX 的 JIT 目前与动态(数据相关)形状的数组不兼容,因此无法完成您的问题。

有一些关于在 JAX 转换(如 JIT)中处理动态形状的实验性工作正在进行中(请参阅 https://github.com/google/jax/pull/9335),但我不确定它何时可以使用。

通常的解决方法是根据具有填充值的静态形状数组重新表达您的计算;例如,你可以使用这样的东西:

a = jnp.where((a < 5), size=len(a), fill_value=np.nan)

这将创建一个长度与 a 相同的数组,在前面有非 nan 值,并在末尾填充 nan 值。

关于python - 处理 jax numpy 数组中的不同形状(jit 兼容),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/71692885/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com