gpt4 book ai didi

tensorflow - 如何使用 tf.cond 进行批处理

转载 作者:行者123 更新时间:2023-12-04 04:02:49 31 4
gpt4 key购买 nike

我想用tf.cond(pred, fn1, fn2, name=None)用于条件分支。假设我有两个张量:x, y .每个张量都是一批 0/1,我想使用这个张量压缩 x < y作为来源
tf.cond pred争论:

pred: A scalar determining whether to return the result of fn1 or fn2.


但是,如果我正在处理批处理,那么看起来我需要遍历图中的源张量并为批处理中的每个项目制作切片并为每个项目应用 tf.cond。至于我,看起来很可疑。为什么 tf.cond 不接受批处理而只接受标量?您能否建议将其与批处理一起使用的正确方法是什么?

最佳答案

tf.where听起来像您想要的:张量之间的矢量化选择。
tf.cond是一个控制流修饰符:它决定执行哪些操作,因此很难想到有用的批处理语义。

我们还可以将这些操作组合在一起:根据条件进行切片并将这些切片传递给两个分支的操作​​。

import tensorflow as tf
from tensorflow.python.util import nest

def slicing_where(condition, full_input, true_branch, false_branch):
"""Split `full_input` between `true_branch` and `false_branch` on `condition`.

Args:
condition: A boolean Tensor with shape [B_1, ..., B_N].
full_input: A Tensor or nested tuple of Tensors of any dtype, each with
shape [B_1, ..., B_N, ...], to be split between `true_branch` and
`false_branch` based on `condition`.
true_branch: A function taking a single argument, that argument having the
same structure and number of batch dimensions as `full_input`. Receives
slices of `full_input` corresponding to the True entries of
`condition`. Returns a Tensor or nested tuple of Tensors, each with batch
dimensions matching its inputs.
false_branch: Like `true_branch`, but receives inputs corresponding to the
false elements of `condition`. Returns a Tensor or nested tuple of Tensors
(with the same structure as the return value of `true_branch`), but with
batch dimensions matching its inputs.
Returns:
Interleaved outputs from `true_branch` and `false_branch`, each Tensor
having shape [B_1, ..., B_N, ...].
"""
full_input_flat = nest.flatten(full_input)
true_indices = tf.where(condition)
false_indices = tf.where(tf.logical_not(condition))
true_branch_inputs = nest.pack_sequence_as(
structure=full_input,
flat_sequence=[tf.gather_nd(params=input_tensor, indices=true_indices)
for input_tensor in full_input_flat])
false_branch_inputs = nest.pack_sequence_as(
structure=full_input,
flat_sequence=[tf.gather_nd(params=input_tensor, indices=false_indices)
for input_tensor in full_input_flat])
true_outputs = true_branch(true_branch_inputs)
false_outputs = false_branch(false_branch_inputs)
nest.assert_same_structure(true_outputs, false_outputs)
def scatter_outputs(true_output, false_output):
batch_shape = tf.shape(condition)
scattered_shape = tf.concat(
[batch_shape, tf.shape(true_output)[tf.rank(batch_shape):]],
0)
true_scatter = tf.scatter_nd(
indices=tf.cast(true_indices, tf.int32),
updates=true_output,
shape=scattered_shape)
false_scatter = tf.scatter_nd(
indices=tf.cast(false_indices, tf.int32),
updates=false_output,
shape=scattered_shape)
return true_scatter + false_scatter
result = nest.pack_sequence_as(
structure=true_outputs,
flat_sequence=[
scatter_outputs(true_single_output, false_single_output)
for true_single_output, false_single_output
in zip(nest.flatten(true_outputs), nest.flatten(false_outputs))])
return result

一些例子:
vector_test = slicing_where(
condition=tf.equal(tf.range(10) % 2, 0),
full_input=tf.range(10, dtype=tf.float32),
true_branch=lambda x: 0.2 + x,
false_branch=lambda x: 0.1 + x)

cross_range = (tf.range(10, dtype=tf.float32)[:, None]
* tf.range(10, dtype=tf.float32)[None, :])
matrix_test = slicing_where(
condition=tf.equal(tf.range(10) % 3, 0),
full_input=cross_range,
true_branch=lambda x: -x,
false_branch=lambda x: x + 0.1)

with tf.Session():
print(vector_test.eval())
print(matrix_test.eval())

打印:
[ 0.2         1.10000002  2.20000005  3.0999999   4.19999981  5.0999999
6.19999981 7.0999999 8.19999981 9.10000038]
[[ 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. ]
[ 0.1 1.10000002 2.0999999 3.0999999 4.0999999
5.0999999 6.0999999 7.0999999 8.10000038 9.10000038]
[ 0.1 2.0999999 4.0999999 6.0999999 8.10000038
10.10000038 12.10000038 14.10000038 16.10000038 18.10000038]
[ 0. -3. -6. -9. -12. -15.
-18. -21. -24. -27. ]
[ 0.1 4.0999999 8.10000038 12.10000038 16.10000038
20.10000038 24.10000038 28.10000038 32.09999847 36.09999847]
[ 0.1 5.0999999 10.10000038 15.10000038 20.10000038
25.10000038 30.10000038 35.09999847 40.09999847 45.09999847]
[ 0. -6. -12. -18. -24. -30.
-36. -42. -48. -54. ]
[ 0.1 7.0999999 14.10000038 21.10000038 28.10000038
35.09999847 42.09999847 49.09999847 56.09999847 63.09999847]
[ 0.1 8.10000038 16.10000038 24.10000038 32.09999847
40.09999847 48.09999847 56.09999847 64.09999847 72.09999847]
[ 0. -9. -18. -27. -36. -45.
-54. -63. -72. -81. ]]

关于tensorflow - 如何使用 tf.cond 进行批处理,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42159946/

31 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com