gpt4 book ai didi

while-loop - 带有 TensorArray 的 TensorFlow while 循环

转载 作者:行者123 更新时间:2023-12-02 22:22:01 31 4
gpt4 key购买 nike

import tensorflow as tf

B = 3
D = 4
T = 5

tf.reset_default_graph()

xs = tf.placeholder(shape=[T, B, D], dtype=tf.float32)

with tf.variable_scope("RNN"):
GRUcell = tf.contrib.rnn.GRUCell(num_units = D)
cell = tf.contrib.rnn.MultiRNNCell([GRUcell])

output_ta = tf.TensorArray(size=T, dtype=tf.float32)
input_ta = tf.TensorArray(size=T, dtype=tf.float32)
input_ta.unstack(xs)

def body(time, output_ta_t, state):
xt = input_ta.read(time)
new_output, new_state = cell(xt, state)
output_ta_t.write(time, new_output)
return (time+1, output_ta_t, new_state)

def condition(time, output, state):
return time < T

time = 0
state = cell.zero_state(B, tf.float32)

time_final, output_ta_final, state_final = tf.while_loop(
cond=condition,
body=body,
loop_vars=(time, output_ta, state))

output_final = output_ta_final.stack()

我运行它

x = np.random.normal(size=(T, B, D))
with tf.Session() as sess:
tf.global_variables_initializer().run()
output_final_, state_final_ = sess.run(fetches = [output_final, state_final], feed_dict = {xs:x})

我想了解如何正确使用与 TensorFlow while 循环相关的 TensorArray。在上面的示例中,我收到以下错误:

InvalidArgumentError: TensorArray RNN/TensorArray_1_21: Could not read from TensorArray index 0 because it has not yet been written to.

我不明白这个“无法从 TensorArray 索引 0 读取”。我想我通过 unstack 写入 TensorArray input_ta 并在 while 主体中写入 output_ta 。我做错了什么?感谢您的帮助。

最佳答案

解决办法就是改变

input_ta.unstack(xs)

input_ta = input_ta.unstack(xs)

并进行类似的更改

output_ta_t.write(time, new_output)

output_ta_t = output_ta_t.write(time, new_output)

通过这两项更改,代码将按预期运行。

关于while-loop - 带有 TensorArray 的 TensorFlow while 循环,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43701631/

31 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com