gpt4 book ai didi

python - 从 tensorflow 模型中获取权重

转载 作者:太空狗 更新时间:2023-10-30 02:28:34 25 4
gpt4 key购买 nike

您好,我想从 tensorflow 微调 VGG 模型。我有两个问题。

如何从网络中获取权重? trainable_variables 为我返回空列表。

我使用了这里的现有模型:https://github.com/ry/tensorflow-vgg16 .我找到了关于获取权重的帖子,但是由于 import_graph_def 这对我不起作用。 Get the value of some weights in a model trained by TensorFlow

import tensorflow as tf
import PIL.Image
import numpy as np

with open("../vgg16.tfmodel", mode='rb') as f:
fileContent = f.read()

graph_def = tf.GraphDef()
graph_def.ParseFromString(fileContent)

images = tf.placeholder("float", [None, 224, 224, 3])

tf.import_graph_def(graph_def, input_map={ "images": images })
print("graph loaded from disk")

graph = tf.get_default_graph()

cat = np.asarray(PIL.Image.open('../cat224.jpg'))
print(cat.shape)
init = tf.initialize_all_variables()

with tf.Session(graph=graph) as sess:
print(tf.trainable_variables() )
sess.run(init)

最佳答案

pretrained VGG-16 model将所有模型参数编码为 tf.constant()操作。 (例如,参见对 tf.constant() here 的调用。)因此,模型参数不会出现在 tf.trainable_variables() 中,并且模型在没有大量手术的情况下是不可改变的:您需要将常量节点替换为 tf.Variable以相同值开始的对象,以便继续训练。

一般来说,在导入图形进行再训练时,tf.train.import_meta_graph()应该使用函数,因为这个函数加载额外的元数据(包括变量集合)。 tf.import_graph_def()函数处于较低级别,不会填充这些集合。

关于python - 从 tensorflow 模型中获取权重,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36412065/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com