gpt4 book ai didi

tensorflow Estimator 无法初始化全局变量

转载 作者:行者123 更新时间:2023-12-04 17:49:37 31 4
gpt4 key购买 nike

我正在使用 tensorflow slim resnet_v2 来提取图像特征。resnet_v2_152.ckpt 来自:resnet_v2_152.ckpt这是我的代码。

import tensorflow as tf

import tensorflow.contrib.slim.python.slim.nets.resnet_v2 as resnet_v2


def cnn_model_fn(features, labels, mode):
net, end_points = resnet_v2.resnet_v2_152(inputs=features, is_training=mode == tf.estimator.ModeKeys.TRAIN)
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode, predictions=net)
else:
raise NotImplementedError('only support predict!')


def parse_filename(filename):
image_string = tf.read_file(filename)
image_decoded = tf.image.decode_jpeg(image_string, channels=3)
image_resized = tf.image.resize_images(image_decoded, [256, 256])
return image_resized


def dataset_input_fn(dataset, num_epochs=None, batch_size=128, shuffle=False, buffer_size=1000, seed=None):
def input_fn():
d = dataset.repeat(num_epochs).batch(batch_size)
if shuffle:
d = d.shuffle(buffer_size)
iterator = d.make_one_shot_iterator()
next_example = iterator.get_next()
return next_example

return input_fn


filenames = sorted(tf.gfile.Glob('/root/data/COCO/download/val2014/*'))
dataset = tf.contrib.data.Dataset.from_tensor_slices(filenames).map(parse_filename)

input_fn = dataset_input_fn(dataset, num_epochs=1, batch_size=1, shuffle=False)

estimator = tf.estimator.Estimator(model_fn=cnn_model_fn, model_dir=None)

es = estimator.predict(input_fn=input_fn,
checkpoint_path='/root/data/checkpoints/resnet_v2_152_2017_04_14/resnet_v2_152.ckpt')
print(es.__next__())


print("Done!")

我得到了这样的错误:

2017-09-10 22:06:36.875590: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Tensor name "resnet_v2_152/block1/unit_1/bottleneck_v2/conv1/biases" not found in checkpoint files /root/data/checkpoints/resnet_v2_152_2017_04_14/resnet_v2_152.ckpt
[[Node: save/RestoreV2_1 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_1/tensor_names, save/RestoreV2_1/shape_and_slices)]]
Traceback (most recent call last):
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1327, in _do_call
return fn(*args)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1306, in _run_fn
status, run_metadata)
File "/usr/lib/python3.5/contextlib.py", line 66, in __exit__
next(self.gen)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/errors_impl.py", line 466, in raise_exception_on_not_ok_status
pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors_impl.NotFoundError: Tensor name "resnet_v2_152/block1/unit_1/bottleneck_v2/conv1/biases" not found in checkpoint files /root/data/checkpoints/resnet_v2_152_2017_04_14/resnet_v2_152.ckpt
[[Node: save/RestoreV2_1 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_1/tensor_names, save/RestoreV2_1/shape_and_slices)]]
[[Node: save/RestoreV2_242/_309 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/gpu:0", send_device="/job:localhost/replica:0/task:0/cpu:0", send_device_incarnation=1, tensor_name="edge_1240_save/RestoreV2_242", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:0"]()]]

我想我可以通过将 conv1/biases 初始化为 0 来解决这个问题,但是 tensorflow Estimator 没有给我这样的功能。我该如何解决?

最佳答案

我认为,您希望加载预训练的权重,而不仅仅是在 resnet 中初始化变量。您应该考虑使用 tf.train.Scaffold目的。

模型例程应该是这样的

def cnn_model_fn(features, labels, mode):
with slim.arg_scope(resnet_v2.resnet_arg_scope()):
logits, end_points = resnet_v2.resnet_v2_152(features,
is_training=mode == tf.estimator.ModeKeys.TRAIN)

checkpoint_file = 'resnet_v2_152.ckpt'
init_fn = slim.assign_from_checkpoint_fn(
checkpoint_file,
[var for var in tf.global_variables()])

saver = tf.train.Saver(max_to_keep=10)
scaffold = tf.train.Scaffold(init_fn=init_fn, saver=saver)

if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode,
predictions={'logits': logits},
scaffold=scaffold)
else:
raise NotImplementedError('only support predict!')

关于tensorflow Estimator 无法初始化全局变量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46141696/

31 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com