gpt4 book ai didi

concatenation - 如何在 TensorFlow 中对批处理进行切片并对每个切片应用操作

转载 作者:行者123 更新时间:2023-12-01 12:29:33 25 4
gpt4 key购买 nike

我是 TensorFlow 的初学者,我正在尝试实现一个将批处理作为输入的函数。它必须将这个批次分成几个,对它们应用一些操作,然后将它们连接起来以构建一个新的张量以返回。通过我的阅读,我发现有一些实现的函数,如 input_slice_producer 和 batch_join,但我没有开始使用它们。我附上了我在下面找到的解决方案,但它有点慢,不正确并且无法检测当前批次的大小。有没有人知道这样做的更好方法?

def model(x):

W_1 = tf.Variable(tf.random_normal([6,1]),name="W_1")
x_size = x.get_shape().as_list()[0]
# x is a batch of bigger input of shape [None,6], so I couldn't
# get the proper size of the batch when feeding it
if x_size == None:
x_size= batch_size
#intialize the y_res
dummy_x = tf.slice(x,[0,0],[1,6])
result = tf.reduce_sum(tf.mul(dummy_x,W_1))
y_res = tf.zeros([1], tf.float32)
y_res = result
#go throw all slices and concatenate them to get result
for i in range(1,x_size):
dummy_x = tf.slice(x,[i,0],[1,6])
result = tf.reduce_sum(tf.mul(dummy_x,W_1))
y_res = tf.concat(0, [y_res, result])

return y_res

最佳答案

TensorFlow 函数 tf.map_fn(fn, elems) 允许您将函数 ( fn ) 应用于张量的每个切片 ( elems )。例如,您可以按如下方式表达您的程序:

def model(x):
W_1 = tf.Variable(tf.random_normal([6, 1]), name="W_1")

def fn(x_slice):
return tf.reduce_sum(x_slice, W_1)

return tf.map_fn(fn, x)

还可以使用 tf.mul() 上的广播更简洁地实现您的操作。运算符,使用 NumPy broadcasting semantics ,以及 axis论据 tf.reduce_sum() .

关于concatenation - 如何在 TensorFlow 中对批处理进行切片并对每个切片应用操作,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/35575982/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com