gpt4 book ai didi

tensorflow - 如何显式广播张量以匹配 tensorflow 中的另一个形状?

转载 作者:行者123 更新时间:2023-12-03 23:35:17 25 4
gpt4 key购买 nike

我有三个张量,A, B and C在 tensorflow 中,AB都是形状 (m, n, r) , C是形状为 (m, n, 1) 的二元张量.

我想根据 C 的值从 A 或 B 中选择元素.明显的工具是tf.select ,但是它没有广播语义,所以我需要首先明确地广播 C形状与 A 和 B 相同。

这将是我第一次尝试如何做到这一点,但它不喜欢我将张量( tf.shape(A)[2] )混合到形状列表中。

import tensorflow as tf
A = tf.random_normal([20, 100, 10])
B = tf.random_normal([20, 100, 10])
C = tf.random_normal([20, 100, 1])
C = tf.greater_equal(C, tf.zeros_like(C))

C = tf.tile(C, [1,1,tf.shape(A)[2]])
D = tf.select(C, A, B)

这里的正确方法是什么?

最佳答案

编辑:在 0.12rc0 之后的所有 TensorFlow 版本中,问题中的代码都可以直接运行。 TensorFlow 会自动将张量和 Python 数字叠加到张量参数中。下面的解决方案使用 tf.pack()仅在 0.12rc0 之前的版本中需要。请注意 tf.pack()更名为 tf.stack() 在 TensorFlow 1.0 中。

您的解决方案非常接近工作。您应该替换该行:

C = tf.tile(C, [1,1,tf.shape(C)[2]])

...具有以下内容:
C = tf.tile(C, tf.pack([1, 1, tf.shape(A)[2]]))

(问题的原因是 TensorFlow 不会将张量列表和 Python 文字隐式转换为张量。 tf.pack() 获取张量列表,因此它将转换其输入中的每个元素( 11tf.shape(C)[2] ) 到张量。由于每个元素都是标量,结果将是一个向量。)

关于tensorflow - 如何显式广播张量以匹配 tensorflow 中的另一个形状?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/34362193/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com