gpt4 book ai didi

jit - 如何使用 JAX 打印

转载 作者:行者123 更新时间:2023-12-03 08:08:15 25 4
gpt4 key购买 nike

我有一个 JAX bool 数组,想要打印一条与 True 总和相结合的语句:

import jax
import jax.numpy as jnp
from jax.experimental.host_callback import id_print

@jax.jit
def overlaps_jax():
mask_cp = jnp.array([True, False, False, True, False, True, False, True, True])
id_print(jnp.sum(mask_cp))

overlaps_jax()

mask_cp中有 5 个True;我想打印为:

With jax accelerator
There are 5 true bools

因为这个函数是jitted ,我尝试使用 id_print 打印此内容,但我不能。 id_print(jnp.sum(mask_cp)) 将打印 5,但我无法将其与字符串一起使用。我已经尝试过以下方法:

id_print(jnp.sum(mask_cp))
# print:
# 5

id_print("\nWith jax accelerator\nThere are " + jnp.sum(mask_cp) + " true bools\n")
# error:
# TypeError: can only concatenate str (not "DynamicJaxprTracer") to str

print("\nWith jax accelerator\nThere are {} true bools\n".format(jnp.sum(mask_cp)))
# print:
# With jax accelerator
# There are Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)> true bools

如何在此代码中打印此类语句?

最佳答案

请注意,id_print 是实验性的,其 API 和功能可能会发生变化。也就是说,我不相信 id_print 能够添加这样的文本,但您可以通过更通用的 host_callback.call 来做到这一点:

import jax
import jax.numpy as jnp
from jax.experimental.host_callback import call

@jax.jit
def overlaps_jax():
mask_cp = jnp.array([True, False, False, True, False, True, False, True, True])
call(lambda x: print(f"There are {x} true bools"), jnp.sum(mask_cp))

overlaps_jax()

输出为

There are 5 true bools

关于jit - 如何使用 JAX 打印,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/71548823/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com