gpt4 book ai didi

python - 为什么 tensorflow 在多次训练 Estimator 时说 Tensor 不是该图的元素?

转载 作者:太空宇宙 更新时间:2023-11-03 12:43:58 26 4
gpt4 key购买 nike

考虑以下代码:

import tensorflow as tf

from tensorflow.python.estimator.model_fn import EstimatorSpec
from tensorflow.contrib.keras.api.keras.layers import Dense


def model_fn_1(features, labels, mode):
x = [[1]]
labels = [[10]]
m = tf.constant([[1, 2], [3, 4]], tf.float32)
lookup = tf.nn.embedding_lookup(m, x, name='embedding_lookup')

preds = Dense(1)(lookup)
loss = tf.reduce_mean(labels - preds)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss, tf.train.get_global_step())

eval_metric_ops = {'accuracy': tf.metrics.accuracy(labels, preds)}
return EstimatorSpec(mode=mode, loss=loss, train_op=train_op, eval_metric_ops=eval_metric_ops)


model_1 = tf.estimator.Estimator(model_fn_1)
model_1.train(input_fn=lambda: None, steps=1)

正如预期的那样,我可以多次执行 model_1.train(input_fn=lambda: None, steps=1) 并且训练将从上一次执行继续。

现在,考虑以下代码:

import tensorflow as tf
import numpy as np

from tensorflow.python.estimator.model_fn import EstimatorSpec
from tensorflow.contrib.keras.api.keras.layers import Embedding, Dense

def model_fn_2(features, labels, mode):
x = tf.constant([[1]])
labels = [[10]]
m = np.array([[1, 2], [3, 4]])
m = Embedding(2, 2, weights=[m], input_length=1, name='embedding_lookup')
lookup = m(x)

preds = Dense(1)(lookup)
loss = tf.reduce_mean(labels - preds)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss, tf.train.get_global_step())

eval_metric_ops = {'accuracy': tf.metrics.accuracy(labels, preds)}
return EstimatorSpec(mode=mode, loss=loss, train_op=train_op, eval_metric_ops=eval_metric_ops)


model_2 = tf.estimator.Estimator(model_fn_2)
model_2.train(input_fn=lambda: None, steps=1)

在这种情况下,我只能执行一次 model_2.train(input_fn=lambda: None, steps=1) ,当我再次尝试执行它时,出现以下错误:

ValueError: Fetch argument cannot be interpreted as a Tensor. (Tensor Tensor("embedding_lookup/embeddings:0", shape=(2, 2), dtype=float32_ref) is not an element of this graph.)

为什么会发生这种情况,我该如何解决?

最佳答案

这可能是 tensorflow keras 后端中的错误或不受支持的情况: session 被全局缓存并且未被清除。您可以通过调用手动清除它:

from tensorflow.contrib.keras.python.keras.backend import clear_session
clear_session()

...在 train 调用之间。

简短原因:第二个 train 调用构建了一个包含新节点的新图,但是底层 session 保留了之前的图,这使得它们不兼容。

更新。在最新的tensorflow中,keras被移到了另一个包中,现在看起来更简单了:

from keras.backend import clear_session
clear_session()

关于python - 为什么 tensorflow 在多次训练 Estimator 时说 Tensor 不是该图的元素?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46911596/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com