gpt4 book ai didi

TensorFlow:加载检查点,但只是其中的一部分(卷积层)

转载 作者:行者123 更新时间:2023-12-04 19:41:57 25 4
gpt4 key购买 nike

是否可以仅从一个检查点文件中加载特定层(卷积层)?

我已经训练了一些完全监督的 CNN 并保存了我的进度(我正在做对象定位)。要进行自动标记,我想从我当前的模型中构建一个弱监督的 CNN……但由于弱监督版本有不同的全连接层,我想只选择我的 TensorFlow 检查点文件的卷积过滤器.

当然,我可以手动保存相应层的权重,但由于它们已经包含在 TensorFlow 的检查点文件中,我想在那里提取它们,以便拥有一个存储文件。

最佳答案

TensorFlow 2.1 有许多不同的公共(public)设施用于加载检查点(model.saveCheckpointsaved_model 等),但据我所知,它们都没有过滤 API。因此,让我建议一个使用 TF2.1 内部开发测试工具的困难案例的片段。

checkpoint_filename = '/path/to/our/weird/checkpoint.ckpt'
model = tf.keras.Model( ... ) # TF2.0 Model to initialize with the above checkpoint
variables_to_load = [ ... ] # List of model weight names to update.

from tensorflow.python.training.checkpoint_utils import load_checkpoint, list_variables

reader = load_checkpoint(checkpoint_filename)
for w in model.weights:
name=w.name.split(':')[0] # See (b/29227106)
if name in variables_to_load:
print(f"Updating {name}")
w.assign(reader.get_tensor(
# (Optional) Handle variable renaming
{'/var_name1/in/model':'/var_name1/in/checkpoint',
'/var_name2/in/model':'/var_name2/in/checkpoint',
# ... and so on
}.get(name,name)))


注: model.weightslist_variables可能有助于检查模型和检查点中的变量

另请注意,此方法不会恢复模型的优化器状态。

关于TensorFlow:加载检查点,但只是其中的一部分(卷积层),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38159532/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com