gpt4 book ai didi

python - 在 tensorflow 中初始化 inception_v3.ckpt 的权重

转载 作者:行者123 更新时间:2023-11-30 22:21:53 25 4
gpt4 key购买 nike

在 tensorflow 中,我需要从 inception_v3 预训练模型加载权重,以便在以下代码中使用:

with tf.variable_scope(scope, "InceptionV3", [images]) as scope:
with slim.arg_scope(
[slim.conv2d, slim.fully_connected],
weights_regularizer=weights_regularizer,
trainable=False):
with slim.arg_scope(
[slim.conv2d],
weights_initializer=tf.truncated_normal_initializer(stddev=stddev),
activation_fn=tf.nn.relu,
normalizer_fn=slim.batch_norm,
normalizer_params=batch_norm_params):
net, end_points = inception_v3_base(images, scope=scope)
with tf.variable_scope("logits"):
shape = net.get_shape()
net = slim.avg_pool2d(net, shape[1:3], padding="VALID", scope="pool")
net = slim.dropout(
net,
keep_prob=dropout_keep_prob,
is_training=False,
scope="dropout")
net = slim.flatten(net, scope="flatten")

image_embeddings = tf.contrib.layers.fully_connected(
inputs=net,
num_outputs=512,
activation_fn=None,
weights_initializer=initializer,
biases_initializer=None,
scope=scope)

如何才能做到这一点?您能举个简单的例子吗?

上面的代码中有两个权重初始值设定项。我不知道我必须在哪一个模型中初始化权重,以及如何初始化?

谢谢

最佳答案

TL;DR:阅读下面列表中的第三点。

如何恢复模型的长篇通用解释

每当您需要从检查点加载权重时,您都需要匹配的模型定义才能在尝试恢复权重之前定义图形。这是必要的,因为检查点文件仅包含变量的值,它没有有关图本身结构的信息

模型结构可以通过不同的方式检索:

  • 检查点附带一个匹配的 .meta 文件。在这种情况下,导入元图,然后通过以下方式恢复权重:

    new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
    new_saver.restore(sess, 'my-save-dir/my-model-10000')
  • 检查点附带一个匹配的 .pb/.pbtxt 文件,其中包含序列化的 GraphDef。在这种情况下,从其定义加载图表,然后恢复权重:

    • 对于.pbtxt:

      with open('graph.pbtxt', 'r') as f:
      graph_def = tf.GraphDef()
      file_content = f.read()
      text_format.Merge(file_content, graph_def)
      tf.import_graph_def(graph_def, name='')
      saver = tf.train.Saver() # note: it is important that this is defined AFTER you import the graph definition or it won't find any variables in the graph to restore
      saver.restore(sess, "/tmp/model.ckpt")
    • 对于.pb:

      with gfile.FastGFile('graph.pb','rb') as f:
      graph_def = tf.GraphDef()
      graph_def.ParseFromString(f.read())
      tf.import_graph_def(graph_def, name='')
      saver = tf.train.Saver() # note: it is important that this is defined AFTER you import the graph definition or it won't find any variables in the graph to restore
      saver.restore(sess, "/tmp/model.ckpt")
  • 检查点附带一个包含模型定义的匹配 python 文件。在这种情况下,请通读该文件的文档并找到需要调用来定义模型的函数。然后,在脚本中导入该函数,在定义saver之前调用它,然后从检查点恢复变量的值:

    from inception_v3 import inception_v3

    logits, endpoints = inception_v3()
    saver = tf.train.Saver() # as above, it is important that this is defined after you define the graph, or it won't find any variables.
    saver.restore(sess, 'inception_v3.ckpt')

    注意:对于这种情况,您需要完全调用保存检查点时调用的函数(除非您有选择地渴望恢复某些变量),否则恢复将失败并出现错误。

关于python - 在 tensorflow 中初始化 inception_v3.ckpt 的权重,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48496552/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com