
python - TensorFlow TypeError: Fetch argument None has invalid type?

Reposted. Author: IT老高. Updated: 2023-10-28 20:52:06

I am building an RNN loosely based on the TensorFlow tutorial.

The relevant part of my model is as follows:

input_sequence = tf.placeholder(tf.float32, [BATCH_SIZE, TIME_STEPS, PIXEL_COUNT + AUX_INPUTS])
output_actual = tf.placeholder(tf.float32, [BATCH_SIZE, OUTPUT_SIZE])

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(CELL_SIZE, state_is_tuple=False)
stacked_lstm = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * CELL_LAYERS, state_is_tuple=False)

initial_state = state = stacked_lstm.zero_state(BATCH_SIZE, tf.float32)
outputs = []

with tf.variable_scope("LSTM"):
    for step in xrange(TIME_STEPS):
        if step > 0:
            tf.get_variable_scope().reuse_variables()
        cell_output, state = stacked_lstm(input_sequence[:, step, :], state)
        outputs.append(cell_output)

final_state = state

And the training and feeding code:

cross_entropy = tf.reduce_mean(-tf.reduce_sum(output_actual * tf.log(prediction), reduction_indices=[1]))
train_step = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(output_actual, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    numpy_state = initial_state.eval()

    for i in xrange(1, ITERATIONS):
        batch = DI.next_batch()

        print i, type(batch[0]), np.array(batch[1]).shape, numpy_state.shape

        if i % LOG_STEP == 0:
            train_accuracy = accuracy.eval(feed_dict={
                initial_state: numpy_state,
                input_sequence: batch[0],
                output_actual: batch[1]
            })

            print "Iteration " + str(i) + " Training Accuracy " + str(train_accuracy)

        numpy_state, train_step = sess.run([final_state, train_step], feed_dict={
            initial_state: numpy_state,
            input_sequence: batch[0],
            output_actual: batch[1]
        })

When I run it, I get the following error:

Traceback (most recent call last):
  File "/home/agupta/Documents/Projects/Image-Recognition-with-LSTM/RNN/feature_tracking/model.py", line 109, in <module>
    output_actual: batch[1]
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 698, in run
    run_metadata_ptr)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 838, in _run
    fetch_handler = _FetchHandler(self._graph, fetches)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 355, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 181, in for_fetch
    return _ListFetchMapper(fetch)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 288, in __init__
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 178, in for_fetch
    (fetch, type(fetch)))
TypeError: Fetch argument None has invalid type <type 'NoneType'>

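Perhaps the strangest part is that the error is thrown on the second iteration, while the first works completely fine. I'm tearing my hair out trying to fix this, so any help would be greatly appreciated.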

Best answer

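You are reassigning the train_step variable to the second element of the result of sess.run(), which happens to be None because fetching a tf.Operation (such as the op returned by minimize()) yields no value. The first iteration therefore runs normally, but on the second iteration train_step is None, and passing it as a fetch raises the error.

Fortunately, the fix is simple: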

for i in xrange(1, ITERATIONS):

    # ...

    # Discard the second element of the result.
    numpy_state, _ = sess.run([final_state, train_step], feed_dict={
        initial_state: numpy_state,
        input_sequence: batch[0],
        output_actual: batch[1]
    })
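
To see why the failure only shows up on the second iteration, the rebinding can be reproduced without TensorFlow. The following is a minimal sketch under stated assumptions: FakeOp, FakeTensor, fake_run, buggy_loop and fixed_loop are hypothetical stand-ins invented for illustration, not TensorFlow APIs. They only mimic the behavior described above, where running an operation yields None and a None fetch raises TypeError.

```python
# Toy stand-ins (hypothetical, NOT TensorFlow) that mimic the fetch behavior
# described in the answer, just enough to reproduce the rebinding bug.

class FakeOp(object):
    """Stands in for an operation such as train_step; running it yields no value."""


class FakeTensor(object):
    """Stands in for a tensor such as final_state; running it yields a value."""

    def __init__(self, value):
        self.value = value


def fake_run(fetches):
    """Return a value for each tensor, None for each op; reject None fetches."""
    results = []
    for fetch in fetches:
        if fetch is None:
            raise TypeError(
                "Fetch argument None has invalid type %r" % type(fetch))
        results.append(fetch.value if isinstance(fetch, FakeTensor) else None)
    return results


def buggy_loop(iterations):
    """Return the iteration on which the loop fails, or None if it never fails."""
    final_state, train_step = FakeTensor(1.0), FakeOp()
    for i in range(1, iterations + 1):
        try:
            # Bug: train_step is rebound to the fetched result, which is None.
            state, train_step = fake_run([final_state, train_step])
        except TypeError:
            return i
    return None


def fixed_loop(iterations):
    """Same loop, but the op's result is discarded so train_step stays intact."""
    final_state, train_step = FakeTensor(1.0), FakeOp()
    for i in range(1, iterations + 1):
        try:
            state, _ = fake_run([final_state, train_step])
        except TypeError:
            return i
    return None


print(buggy_loop(3))  # 2
print(fixed_loop(3))  # None
```

The buggy loop succeeds once because train_step still refers to the op during the first call; only after the assignment does the name point at None. Discarding the result with _, or fetching the op under a different name, keeps the original reference alive.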

Regarding python - TensorFlow TypeError: Fetch argument None has invalid type <type 'NoneType'>?, a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/39114832/
