gpt4 book ai didi

python - 改变 tensorflow 图并恢复训练

转载 作者:行者123 更新时间:2023-12-04 01:27:33 25 4
gpt4 key购买 nike

我正在尝试加载 MCnet model 的预训练权重并恢复训练。此处提供的预训练模型使用参数 K=4, T=7 进行训练。但是,我想要一个参数为 K=4,T=1 的模型。我不想从头开始训练,而是想从这个预训练模型中加载权重。但由于图表已更改,我无法加载预训练模型。

InvalidArgumentError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a mismatch between the current graph and the graph from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Assign requires shapes of both tensors to match. lhs shape= [5,5,15,64] rhs shape= [5,5,33,64]
[[node save/Assign_13 (defined at /media/nagabhushan/Data02/SNB/IISc/Research/04_Gaming_Video_Prediction/Workspace/VideoPrediction/Literature/01_MCnet/src/snb/mcnet.py:108) ]]

是否可以使用新图加载预训练模型?

我尝试过的:
以前,我想将预训练模型从旧版本的 tensorflow 移植到新版本。我得到了 this answer在 SO 中帮助我移植模型。这个想法是创建新图形并从保存的图形加载新图形中存在的变量。

with tf.Session() as sess:
_ = MCNET(image_size=[240, 320], batch_size=8, K=4, T=1, c_dim=3, checkpoint_dir=None, is_train=True)
tf.global_variables_initializer().run(session=sess)

ckpt_vars = tf.train.list_variables(model_path.as_posix())
ass_ops = []
for dst_var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
for (ckpt_var, ckpt_shape) in ckpt_vars:
if dst_var.name.split(":")[0] == ckpt_var and dst_var.shape == ckpt_shape:
value = tf.train.load_variable(model_path.as_posix(), ckpt_var)
ass_ops.append(tf.assign(dst_var, value))

# Assign the variables
sess.run(ass_ops)
saver = tf.train.Saver()
saver.save(sess, save_path.as_posix())

我在这里尝试了同样的方法并且成功了,这意味着我得到了一个新的 K=4,T=1 训练模型。但我不确定它是否有效!我的意思是,权重有意义吗?这是正确的做法吗?

关于模型的信息:
MCnet 是一种用于视频预测的模型,即给定 K 过去的帧,它可以预测接下来的 T 帧。

感谢任何帮助

最佳答案

MCnet 模型有一个生成器和一个鉴别器。生成器是基于 LSTM 的,因此通过改变时间步长 T 的数量来加载权重没有问题。然而,正如他们编码的那样,鉴别器是卷积的。为了在视频上应用卷积层,它们在 channel 维度上连接帧。使用 K=4,T=7,您将获得长度为 11 且具有 3 channel 的视频。当您沿 channel 维度连接它们时,您会得到一个具有 33 channel 的图像。当他们定义判别器时,他们将判别器的第一层定义为具有 33 输入 channel ,因此权重具有相似的维度。但是对于 K=4,T=1,视频长度为 5 并且最终图像有 15 channel ,因此权重将有 15 个 channel .这是您观察到的不匹配错误。要解决此问题,您可以仅从前 15 个 channel 中选取权重(我想不出更好的方法)。代码如下:

with tf.Session() as sess:
_ = MCNET(image_size=[240, 320], batch_size=8, K=4, T=1, c_dim=3, checkpoint_dir=None, is_train=True)
tf.global_variables_initializer().run(session=sess)

ckpt_vars = tf.train.list_variables(model_path.as_posix())
ass_ops = []
for dst_var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
for (ckpt_var, ckpt_shape) in ckpt_vars:
if dst_var.name.split(":")[0] == ckpt_var:
if dst_var.shape == ckpt_shape:
value = tf.train.load_variable(model_path.as_posix(), ckpt_var)
ass_ops.append(tf.assign(dst_var, value))
else:
value = tf.train.load_variable(model_path.as_posix(), ckpt_var)
if dst_var.shape[2] <= value.shape[2]:
adjusted_value = value[:, :, :dst_var.shape[2]]
else:
adjusted_value = numpy.random.random(dst_var.shape)
adjusted_value[:, :, :value.shape[2], ...] = value
ass_ops.append(tf.assign(dst_var, adjusted_value))

# Assign the variables
sess.run(ass_ops)
saver = tf.train.Saver()
saver.save(sess, save_path.as_posix())

关于python - 改变 tensorflow 图并恢复训练,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61612193/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com