gpt4 book ai didi

TensorFlow:NotFoundError:在检查点中找不到 key

转载 作者:行者123 更新时间:2023-12-03 00:39:43 26 4
gpt4 key购买 nike

我已经训练 TensorFlow 模型大约一周时间,偶尔进行微调。

今天,当我尝试微调模型时,出现错误:

tensorflow.python.framework.errors_impl.NotFoundError: Key conv_classifier/loss/total_loss/avg not found in checkpoint
[[Node: save/RestoreV2_37 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_37/tensor_names, save/RestoreV2_37/shape_and_slices)]]

使用inspect_checkpoint.py,我看到检查点文件现在有两个空层:

...
conv_decode4/ort_weights/Momentum (DT_FLOAT) [7,7,64,64]
loss/cross_entropy/avg (DT_FLOAT) []
loss/total_loss/avg (DT_FLOAT) []
up1/up_filter (DT_FLOAT) [2,2,64,64]
...

如何解决这个问题?

解决方案:

为清楚起见,对以下 Mrry 的建议进行了编辑:

code_to_checkpoint_variable_map = {var.op.name: var for var in tf.global_variables()}
for code_variable_name, checkpoint_variable_name in {
"inference/conv_classifier/weight_loss/avg" : "loss/weight_loss/avg",
"inference/conv_classifier/loss/total_loss/avg" : "loss/total_loss/avg",
"inference/conv_classifier/loss/cross_entropy/avg": "loss/cross_entropy/avg",
}.items():
code_to_checkpoint_variable_map[checkpoint_variable_name] = code_to_checkpoint_variable_map[code_variable_name]
del code_to_checkpoint_variable_map[code_variable_name]

saver = tf.train.Saver(code_to_checkpoint_variable_map)
saver.restore(sess, tf.train.latest_checkpoint('./logs'))

最佳答案

幸运的是,您的检查点看起来并没有损坏,而是程序中的某些变量已被重命名。我假设名为 "loss/total_loss/avg" 的检查点值应恢复到名为 "conv_classifier/loss/total_loss/avg" 的变量。您可以通过在创建 tf.train.Saver 时传递自定义 var_list 来解决此问题。

name_to_var_map = {var.op.name: var for var in tf.global_variables()}

name_to_var_map["loss/total_loss/avg"] = name_to_var_map[
"conv_classifier/loss/total_loss/avg"]
del name_to_var_map["conv_classifier/loss/total_loss/avg"]

# Depending on how the names have changed, you may also need to do:
# name_to_var_map["loss/cross_entropy/avg"] = name_to_var_map[
# "conv_classifier/loss/cross_entropy/avg"]
# del name_to_var_map["conv_classifier/loss/cross_entropy/avg"]

saver = tf.train.Saver(name_to_var_map)

然后您可以使用 saver.restore() 恢复您的模型。或者,您可以使用此方法来恢复模型,并使用默认构造的 tf.train.Saver 将其保存为规范格式。

关于TensorFlow:NotFoundError:在检查点中找不到 key ,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46697662/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com