
python - How to train with while_loop and tf.layers.batch_normalization?

Reposted. Author: 行者123. Updated: 2023-12-01 09:30:52

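I need to add a batch_normalization layer inside the body of a while loop, but training the network crashes. If I remove x = tf.layers.batch_normalization(x, training=flag), everything works fine. Can I use the high-level API inside the loop body? I don't want to fall back to tf.nn.batch_normalization, because this is a simplified example and my real network is much more complex.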
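For context: in training mode, tf.layers.batch_normalization normalizes each batch with that batch's own mean and variance, and it also keeps moving averages of those statistics for use at inference time. In graph mode the moving-average assignments are separate ops (the AssignMovingAvg nodes that show up in the error message) that TensorFlow puts in the tf.GraphKeys.UPDATE_OPS collection. The plain-Python sketch below shows only that arithmetic for a 1-D batch. The function name batch_norm_train is hypothetical; it assumes the TF 1.x defaults momentum=0.99 and epsilon=1e-3, uses the population variance, and ignores the learned gamma/beta parameters and fused-kernel details, so it is an illustration rather than TensorFlow's exact implementation.

```python
import math


def batch_norm_train(batch, moving_mean, moving_var, momentum=0.99, epsilon=1e-3):
    """Training-mode batch norm for a 1-D list of floats (hypothetical helper).

    Returns the batch normalized with its own mean/variance, plus the
    updated moving mean and moving variance. The learned gamma/beta
    scale and offset are omitted.
    """
    n = len(batch)
    mean = sum(batch) / n
    var = sum((v - mean) ** 2 for v in batch) / n  # population (biased) variance
    normalized = [(v - mean) / math.sqrt(var + epsilon) for v in batch]
    # In graph mode these two updates are separate assignment ops; nothing in
    # the loss depends on them, which is why they must be run explicitly.
    new_moving_mean = momentum * moving_mean + (1 - momentum) * mean
    new_moving_var = momentum * moving_var + (1 - momentum) * var
    return normalized, new_moving_mean, new_moving_var


normalized, moving_mean, moving_var = batch_norm_train([1.0, 2.0, 3.0], 0.0, 1.0)
print(normalized)
print(moving_mean, moving_var)
```

The forward output (normalized) and the state update (moving_mean, moving_var) are returned separately here to mirror the fact that, in the graph, the update is a side effect that the training step has to trigger on its own.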

import tensorflow as tf  # TensorFlow 1.x graph-mode API
from data_pre import get_data  # the asker's own data-loading helper (not shown)

data, labels = get_data(
    ['../UCR_TS_Archive_2015/ItalyPowerDemand/ItalyPowerDemand_TRAIN'], 24, 2, True, 0, 2)  #pylint: disable=line-too-long

flag = True  # passed as training=flag, so batch norm runs in training mode

def cond(i, x):
    # Loop while i < 1, i.e. the body runs exactly once
    return i < 1

def body(i, x):
    x = tf.layers.conv1d(x, 1, 7, padding='same')
    x = tf.layers.batch_normalization(x, training=flag)  # removing this line makes training work
    x = tf.nn.relu(x)
    return i + 1, x

_, y = tf.while_loop(cond, body, [0, data], back_prop=False)

y = tf.layers.flatten(y)
logits = tf.layers.dense(y, 2)

loss = tf.losses.mean_squared_error(labels, logits)
optimizer = tf.train.AdamOptimizer()
# Make the train op depend on batch norm's moving-average update ops
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss, tf.train.get_global_step())

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for _ in range(10):
        sess.run(train_op)  # raises the InvalidArgumentError shown below
    coord.request_stop()
    coord.join(threads)
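The UPDATE_OPS / tf.control_dependencies lines in the code above are there because nothing in the loss depends on the moving-average assignments, so an executor that only runs what the fetched op needs would silently skip them. The toy model below illustrates that rule; the Op class and run function are hypothetical stand-ins written for this sketch, not TensorFlow APIs. An op runs only if the fetched op reaches it through its data inputs or its control inputs.

```python
class Op:
    """Hypothetical, minimal stand-in for a graph node (not a TensorFlow API)."""

    def __init__(self, name, fn=lambda: None, inputs=(), control_inputs=()):
        self.name = name
        self.fn = fn  # side effect performed when the op executes
        self.inputs = list(inputs)  # data dependencies
        self.control_inputs = list(control_inputs)  # "run these first" dependencies


def run(fetch, executed=None):
    """Execute fetch after everything it depends on, each op at most once.

    Returns the names of the ops that ran, in execution order.
    """
    if executed is None:
        executed = []
    if fetch.name in executed:
        return executed
    for dep in fetch.inputs + fetch.control_inputs:
        run(dep, executed)
    fetch.fn()
    executed.append(fetch.name)
    return executed


state = {"moving_mean": 0.0}


def bump_moving_mean():
    state["moving_mean"] += 1.0


loss = Op("loss")
update = Op("AssignMovingAvg", fn=bump_moving_mean)
train_plain = Op("train_plain", inputs=[loss])
train_with_deps = Op("train_with_deps", inputs=[loss], control_inputs=[update])

print(run(train_plain))      # ['loss', 'train_plain'] (the update never runs)
print(run(train_with_deps))  # ['loss', 'AssignMovingAvg', 'train_with_deps']
```

In this model the control dependency is the only thing that pulls the update into the run, which is the role tf.control_dependencies(update_ops) plays in the question's code.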

Here is the error message:

Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1327, in _do_call
    return fn(*args)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1312, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1420, in _call_tf_sessionrun
    status, run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/errors_impl.py", line 516, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: The node 'gradients/mean_squared_error/div_grad/Neg' has inputs from different frames. The input 'while/batch_normalization/AssignMovingAvg_1' is in frame 'while/while_context'. The input 'one_hot' is in frame ''.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "./test.py", line 40, in <module>
    sess.run(train_op)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 905, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1140, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1321, in _do_run
    run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1340, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: The node 'gradients/mean_squared_error/div_grad/Neg' has inputs from different frames. The input 'while/batch_normalization/AssignMovingAvg_1' is in frame 'while/while_context'. The input 'one_hot' is in frame ''.
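Reading the error message: the gradient node gradients/mean_squared_error/div_grad/Neg ends up with one input from the top-level frame ('one_hot', frame '') and one from the loop frame ('while/batch_normalization/AssignMovingAvg_1', frame 'while/while_context'). A plausible explanation, not confirmed by the post, is that the loop-internal update op arrives as a control input, because the gradient ops are built inside tf.control_dependencies(update_ops) while update_ops contains ops created inside tf.while_loop. The sketch below is a hypothetical model of that consistency rule; the Node class and check_frames function are illustrative and are not TensorFlow internals.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical graph node: a name, its control-flow frame, and its inputs."""
    name: str
    frame: str = ""  # "" stands for the top-level frame (outside any loop)
    inputs: list = field(default_factory=list)  # data and control inputs alike


def check_frames(node):
    """Raise ValueError if node consumes inputs from more than one frame."""
    frames = {inp.frame for inp in node.inputs}
    if len(frames) > 1:
        detail = ", ".join(f"{inp.name!r} in frame {inp.frame!r}" for inp in node.inputs)
        raise ValueError(f"{node.name!r} mixes inputs from different frames: {detail}")
    return True


labels = Node("one_hot")
update_in_loop = Node("while/batch_normalization/AssignMovingAvg_1",
                      frame="while/while_context")
grad = Node("gradients/mean_squared_error/div_grad/Neg",
            inputs=[labels, update_in_loop])

try:
    check_frames(grad)
    mixed_ok = True
except ValueError as err:
    mixed_ok = False
    print(err)

# Hypothetical variant: the same update created in the top-level frame instead.
update_outside = Node("batch_normalization/AssignMovingAvg_1")
grad_ok = Node("gradients/mean_squared_error/div_grad/Neg",
               inputs=[labels, update_outside])
print(check_frames(grad_ok))  # True
```

Under this reading, one direction worth exploring is keeping the moving-average updates out of the loop frame, or not attaching loop-internal update ops to the outer training op. That is an assumption drawn from the error text, not necessarily the fix given in the GitHub issue mentioned in the answer.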

Best answer

I got help on GitHub. If you run into a similar problem, see the GitHub issue titled "The net using while_loop with batch_normalization can't train".

Regarding "python - How to train with while_loop and tf.layers.batch_normalization?", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/49980978/
