gpt4 book ai didi

machine-learning - 有没有办法使用 Tensorflow 实现自动化迁移学习?

转载 作者:行者123 更新时间:2023-11-30 08:52:13 26 4
gpt4 key购买 nike

我正在使用 Tensorflow 构建和训练多个神经网络。这些网络正在对相关任务(自然语言处理)进行监督学习。

我所有的神经网络的共同点是它们共享一些早期层(有些共享另外 2 个层)。

我希望能够共享一个架构中公共(public)层的训练权重来初始化另一个架构。

我目前做事的方式是,每次我想转移权重时,我都会编写一段单独的(临时)代码。这使我的项目变得困惑并且耗时。

有人知道一种方法可以让我自动化重量转移过程吗?例如,要自动检测公共(public)层,然后初始化相应的权重。

最佳答案

您可以创建 tf.Saver专门针对感兴趣的变量集,只要它们具有相同的名称,您就可以恢复另一个图中的变量。您可以使用集合来存储这些变量,然后为集合创建保存程序:

TRANSFERABLE_VARIABLES = "transferable_variable"
# ...
my_var = tf.get_variable(...)
tf.add_to_collection(TRANSFERABLE_VARIABLES, my_var)
# ...
saver = tf.Saver(tf.get_collection(TRANSFERABLE_VARIABLES), ...)

这应该允许您在一个图中调用save并在另一个图中调用restore来传输权重。

如果您想避免将任何内容写入磁盘,那么我认为除了手动复制/粘贴这些值之外没有其他方法。但是,通过使用集合和完全相同的构建过程,这也可以在相当程度上实现自动化:

model1_graph = create_model1()
model2_graph = create_model2()

with model1_graph.as_default(), tf.Session() as sess:
# Train...
# Retrieve learned weights
transferable_weights = sess.run(tf.get_collection(TRANSFERABLE_VARIABLES))

with model2_graph.as_default(), tf.Session() as sess:
# Load weights from the other model
for var, weight in zip(tf.get_collection(TRANSFERABLE_VARIABLES),
transferable_weights):
var.load(weight, sess)
# Continue training...

同样,这仅在公共(public)层的构造相同的情况下才有效,因为两个图的集合中变量的顺序应该相同。

更新:

如果您想确保恢复的变量不用于训练,您有几种可能性,尽管它们可能都需要对代码进行更多更改。 可训练变量只是包含在集合 tf.GrapKeys.TRAINABLE_VARIABLES 中的变量。 ,因此当您在第二个图中创建传输变量时,您可以只说 trainable=False ,恢复过程应该是相同的。如果您想要更加动态并自动执行,这或多或少是可能的,但请记住这一点:必须在创建优化器之前知道必须用于训练的变量列表,并且之后无法更改(无需创建新的优化器)。知道了这一点,我认为没有任何解决方案不通过传递包含第一张图中的可传递变量名称的列表。例如:

with model1_graph.as_default():
transferable_names = [v.name for v in tf.get_collection(TRANSFERABLE_VARIABLES)]

然后,在第二个图的构建过程中,在定义模型之后、创建优化器之前,您可以执行以下操作:

train_vars = [v for v in tf.get_collection(tf.GrapKeys.TRAINABLE_VARIABLES)
if v.name not in transferable_names]
# Assuming that `model2_graph` is the current default graph
tf.get_default_graph().clear_collection(tf.GrapKeys.TRAINABLE_VARIABLES)
for v in train_vars:
tf.add_to_collection(tf.GrapKeys.TRAINABLE_VARIABLES, v)
# Create the optimizer...

另一个选择是不修改集合tf.GrapKeys.TRAINABLE_VARIABLES,而是传递您想要优化的变量列表(示例中的train_vars)作为参数 var_listminimize优化器的方法。原则上我个人不太喜欢这个,因为我认为集合的内容应该符合它们的语义目的(毕竟,代码的其他部分可能会使用相同的集合用于其他目的),但这取决于我猜测的情况。

关于machine-learning - 有没有办法使用 Tensorflow 实现自动化迁移学习?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45569330/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com