gpt4 book ai didi

python - n 批处理 > 1 的多元正态分布

转载 作者:太空宇宙 更新时间:2023-11-03 20:34:47 26 4
gpt4 key购买 nike

我试图概括 How to use a MultiVariateNormal distribution in the latest version of Tensorflow 中给出的示例服从二维正态分布,但具有多个批处理。当我运行以下命令时:

from tensorflow_probability import distributions as tfd
import tensorflow as tf

tf.compat.v1.enable_eager_execution()

mu = [[1, 2],
[-1,-2]]

cov = [[1, 3./5],
[3./5, 2]]

cov = [cov, cov] # for demonstration purpose, use same cov for both batches

mvn = tfd.MultivariateNormalFullCovariance(
loc=mu,
covariance_matrix=cov)

# generate the pdf
X, Y = tf.meshgrid(tf.range(-3, 3, 0.1), tf.range(-3, 3, 0.1))
idx = tf.concat([tf.reshape(X, [-1, 1]), tf.reshape(Y,[-1,1])], axis =1)
prob = tf.reshape(mvn.prob(idx), tf.shape(X))

我收到形状不兼容错误:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [3600,2] vs. [2,2] [Op:Sub] name: MultivariateNormalFullCovariance/log_prob/affine_linear_operator/inverse/sub/

我对文档 ( https://www.tensorflow.org/api_docs/python/tf/contrib/distributions/MultivariateNormalFullCovariance ) 的理解是,要计算 pdf,需要一个 [n_observation, n_dimensions] 张量(本例中就是这种情况:idx.shape = TensorShape([维度(3600),维度(2)]))。我的数学算错了吗?

最佳答案

您需要将批处理轴添加到 idx张量位于倒数第二个位置,因为 60x60 无法针对 mvn.batch_shape 进行广播的(2,) .

# TF/TFP Imports
!pip install --quiet tfp-nightly tf-nightly
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
import tensorflow_probability as tfp
tfd = tfp.distributions

mu = [[1, 2],
[-1, -2]]

cov = [[1, 3./5],
[3./5, 2]]

cov = [cov, cov] # for demonstration purpose, use same cov for both batches

mvn = tfd.MultivariateNormalFullCovariance(
loc=mu, covariance_matrix=cov)
print(mvn.batch_shape, mvn.event_shape)

# generate the pdf
X, Y = tf.meshgrid(tf.range(-3, 3, 0.1), tf.range(-3, 3, 0.1))
print(X.shape)
idx = tf.stack([X, Y], axis=-1)[..., tf.newaxis, :]
print(idx.shape)

probs = mvn.prob(idx)
print(probs.shape)

输出:

(2,) (2,)   # mvn.batch_shape, mvn.event_shape
(60, 60) # X.shape
(60, 60, 1, 2) # idx.shape == X.shape + (1 "broadcast against batch", 2 "event")
(60, 60, 2) # probs.shape == X.shape + (2 "mvn batch shape")

关于python - n 批处理 > 1 的多元正态分布,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57238774/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com