gpt4 book ai didi

python - 在体内创建变量时正确使用 tf.while_loop

转载 作者:太空狗 更新时间:2023-10-30 00:18:32 27 4
gpt4 key购买 nike

我在 Tensorflow 中使用 while_loop 来迭代张量并提取给定维度上的特定切片。对于每一步,我都需要使用解码器 RNN 来生成一系列输出符号。我正在使用 tf.contrib.seq2seq 中提供的代码,特别是 tf.contrib.seq2seq.dynamic_decode .代码类似于以下内容:

def decoder_condition(i, data, source_seq_len, ta_outputs):
return tf.less(i, max_loop_len)

def decode_body(i, data, source_seq_len, ta_outputs):
curr_data = data[:, i, :]
curr_source_seq_len = source_seq_len[:, i, :]
attention_mechanism = tf.contrib.seq2seq.LuongAttention(
2 * self.opt["encoder_rnn_h_size"],
curr_data,
memory_sequence_length=curr_source_seq_len
)
cell = GRUCell(num_units)
cell = AttentionWrapper(cell, attention_mechanism)
# ... other code that initialises all the variables required
# for the RNN decoder
outputs = tf.contrib.seq2seq.dynamic_decode(
decoder,
maximum_iterations=self.opt["max_sys_seq_len"],
swap_memory=True
)
with tf.control_dependencies([outputs)]:
ta_outputs = ta_outputs.write(i, outputs)

return i+1, data, ta_outputs

loop_index = tf.constant(0)
gen_outputs = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
outputs = tf.while_loop(
decoder_condition,
decoder_body,
loop_vars=[
loop_index,
data,
data_source_len,
ta_outputs
],
swap_memory=True,
back_prop=True,
parallel_iterations=1
)

如您所见,我创建了不同的对象,这些对象具体取决于当前步骤的输入i。我在我当前的变量范围内使用 tf.AUTO_REUSE,这样即使我创建不同的对象,变量也会被重用。不幸的是,我的解码器似乎没有正确训练,因为它不断生成不正确的值。我已经检查了解码器 RNN 的输入数据,一切都是正确的。我怀疑在 TensorFlow 如何管理 TensorArray 和 while_loop 方面我做得不对。

所以我的主要问题是:

  1. TensorFlow 是否正确传播它在 while 循环内创建的每个变量的梯度?
  2. 是否可以在 while 循环内创建依赖于使用循环索引获得的张量的特定切片的对象?
  3. backprop 参数是否保证在训练期间传播梯度?是否应该在推理期间将其设置为 False?
  4. 一般来说,是否有任何完整性检查可用于发现我的实现中可能存在的错误?

谢谢!

更新:不知道为什么,但似乎有一个 Unresolved 问题与在 while 循环中调用自定义操作的可能性有关,如下所述:https://github.com/tensorflow/tensorflow/issues/13616 .不幸的是,我对 TensorFlow 的内部了解不够,无法判断它是否与此完全相关。

更新 2:我用 PyTorch 解决了 :)

最佳答案

(1) 是

(2) 是的,只需使用循环索引对张量进行切片

(3) 普通用例无需设置backprop=False

(4) 使用 ML 模型的常用操作(玩具数据集、单独测试部件等)

重新更新2,尝试使用eager execution或者tf.contrib.autograph;两者都应该让你用纯 python 编写 while 循环。

关于python - 在体内创建变量时正确使用 tf.while_loop,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51466554/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com