
python - Debugging arrays in a jax vmap function


Dear jax experts, I need your help.
Here is a working example (I have simplified my code as suggested; I am not an expert in jax, nor in Python, so I cannot guess what is at the heart of the mechanism involved in vmap).

import jax
import jax.numpy as jnp
import scipy.stats as sci_stats

def jax_kernel(rng_key, logpdf, position, log_prob):
    # propose a random-walk move around the current position
    key, subkey = jax.random.split(rng_key)
    move_proposals = jax.random.normal(key, shape=position.shape) * 0.1
    proposal = position + move_proposals
    proposal_log_prob = logpdf(proposal)
    return proposal, proposal_log_prob

def jax_sampler(rng_key, n_samples, logpdf, initial_position):

    def mh_update(i, state):
        key, positions, log_prob = state
        _, key = jax.random.split(key)
        print(f"mh_update: positions[{i-1}]:", jnp.asarray(positions[i-1]))
        new_position, new_log_prob = jax_kernel(key, logpdf, positions[i-1], log_prob)
        positions = positions.at[i].set(new_position)
        return (key, positions, new_log_prob)

    # the all_positions structure should be set up before lax.fori_loop
    print("initial_position shape:", initial_position.shape)
    all_positions = jnp.zeros((n_samples,) + initial_position.shape)
    all_positions = all_positions.at[0, 0].set(1.)
    all_positions = all_positions.at[0, 1].set(2.)
    all_positions = all_positions.at[0, 2].set(2.)
    print("all_positions init:", all_positions.shape)
    logp = logpdf(all_positions[0])

    # use a plain Python for-loop instead of jax.lax.fori_loop so that mh_update can be debugged
    initial_state = (rng_key, all_positions, logp)
    val = initial_state
    for i in range(1, n_samples):
        val = mh_update(i, val)
    rng_key, all_positions, log_prob = val
    # return all the positions of the parameters (n_chains, n_samples, n_dim)
    return all_positions

def func(par):
    # linear model with parameters (a, b, s) = (par[0], par[1], par[2]);
    # yi (the observed data) is assumed to be defined elsewhere in the original script
    xi = jnp.asarray(sci_stats.uniform.rvs(size=10))
    val = xi * par[1] + par[0]
    return jnp.sum(jax.scipy.stats.norm.logpdf(x=val, loc=yi, scale=par[2]))


n_dim = 3       # number of parameters, i.e. (a, b, s)
n_samples = 5   # number of samples per chain
n_chains = 4    # number of MCMC chains
rng_key = jax.random.PRNGKey(42)
rng_keys = jax.random.split(rng_key, n_chains)
initial_position = jnp.ones((n_dim, n_chains))
print("main initial_position shape", initial_position.shape)
# vmap over the chains: axis 0 of rng_keys and axis 1 of initial_position are mapped
run = jax.vmap(jax_sampler, in_axes=(0, None, None, 1), out_axes=0)
all_positions = run(rng_keys, n_samples, lambda p: func(p), initial_position)
print("all_positions:", all_positions)
My question concerns how the dimensions evolve in print(f"mh_update: positions[{i-1}]:", jnp.asarray(positions[i-1])). I don't understand why positions[i-1] starts out with dimension n_dim and then switches to n_chains x n_dim.
Thanks in advance for your input.
Here is the full output:
main initial_position shape (3, 4)
initial_position shape: (3,)
all_positions init: (5, 3)
mh_update: positions[0]: [1. 2. 2.]
mh_update: positions[1]: Traced<ShapedArray(float32[3])>with<BatchTrace(level=1/0)>
with val = DeviceArray([[0.9354116 , 1.7876872 , 1.8443539 ],
[0.9844745 , 2.073029 , 1.9511036 ],
[0.98202926, 2.0109322 , 2.094176 ],
[0.9536771 , 1.9731759 , 2.093319 ]], dtype=float32)
batch_dim = 0
mh_update: positions[2]: Traced<ShapedArray(float32[3])>with<BatchTrace(level=1/0)>
with val = DeviceArray([[1.0606856, 1.6707807, 1.8377957],
[1.0465866, 1.9754674, 1.7009288],
[1.1107644, 2.0142047, 2.190575 ],
[1.0089972, 1.9953227, 1.996874 ]], dtype=float32)
batch_dim = 0
mh_update: positions[3]: Traced<ShapedArray(float32[3])>with<BatchTrace(level=1/0)>
with val = DeviceArray([[1.0731456, 1.644405 , 2.1343162],
[1.0599504, 2.0121546, 1.6867112],
[1.0585173, 1.9661485, 2.1573594],
[1.1213307, 1.9335203, 1.9683584]], dtype=float32)
batch_dim = 0
all_positions: [[[1. 2. 2. ]
[0.9354116 1.7876872 1.8443539 ]
[1.0606856 1.6707807 1.8377957 ]
[1.0731456 1.644405 2.1343162 ]
[1.0921828 1.5742197 2.058759 ]]

[[1. 2. 2. ]
[0.9844745 2.073029 1.9511036 ]
[1.0465866 1.9754674 1.7009288 ]
[1.0599504 2.0121546 1.6867112 ]
[1.0835105 2.0051234 1.4766487 ]]

[[1. 2. 2. ]
[0.98202926 2.0109322 2.094176 ]
[1.1107644 2.0142047 2.190575 ]
[1.0585173 1.9661485 2.1573594 ]
[1.1728328 1.981367 2.180744 ]]

[[1. 2. 2. ]
[0.9536771 1.9731759 2.093319 ]
[1.0089972 1.9953227 1.996874 ]
[1.1213307 1.9335203 1.9683584 ]
[1.1148386 1.9598911 2.1721165 ]]]

Best answer

In the first iteration, you are printing a concrete array constructed inside the vmapped function. It is a float32 array of shape (3,).
After the first iteration, you have constructed a new array via operations on the vmapped input. When you vmap over an input like this, JAX replaces your input array with a tracer, which is an abstract representation of your input; its printed value looks like this:

Traced<ShapedArray(float32[3])>with<BatchTrace(level=1/0)>
with val = DeviceArray([[1.0731456, 1.644405 , 2.1343162],
[1.0599504, 2.0121546, 1.6867112],
[1.0585173, 1.9661485, 2.1573594],
[1.1213307, 1.9335203, 1.9683584]], dtype=float32)
float32[3] means that this tracer represents a float32 array of shape (3,): that is, it still has the same type and shape as in the first iteration. But in this case it is not a concrete three-element array; it is a batch tracer that represents each iteration of the vmapped input. The power of the vmap transform is that JAX effectively traces all of the implied iterations of the vmapped computation in a single pass: in the tracer representation, the rows of val effectively show you the intermediate values for all of the vmapped iterations.
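As a minimal illustration of this behaviour (a sketch added here, not part of the original answer): inside a vmapped function, a constant built from scratch prints as an ordinary concrete array, while any array derived from the vmapped input prints as a batch tracer whose val stacks the values of every mapped iteration.

import jax
import jax.numpy as jnp

def f(x):
    const = jnp.zeros(3)   # built inside the function: prints as a concrete array of shape (3,)
    y = x * 2.0            # derived from the vmapped input: prints as a batch tracer
    print("const:", const)
    print("y:", y)         # Traced<ShapedArray(float32[3])> ... with val of shape (4, 3)
    return y

batched = jnp.ones((4, 3))               # 4 "chains", each carrying 3 parameters
out = jax.vmap(f, in_axes=0)(batched)    # out has shape (4, 3)

This mirrors the question above: positions[0] is filled with constants, so it prints as a concrete array, whereas positions[1] onward are derived from the vmapped rng_key, so they print as batch tracers whose val has shape (n_chains, n_dim).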
To learn more about how JAX tracing works, read How To Think In JAX in the JAX documentation.

Regarding "python - Debugging arrays in a jax vmap function", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/68380291/
