
python - Nested gradient tapes in a function (TF 2.0)


I am trying to implement MAML. For that I need a copy of my model (model_copy) that is trained for one step, and then I need my meta_model to be trained with the loss of that model_copy.

I would like to do the training of model_copy inside a function. If I move my code into the function, I no longer get proper gradients_meta (they are all None).

It seems as if the graphs are disconnected - how can I connect them?

Any idea what I am doing wrong? I do watch a lot of variables, but that does not seem to make a difference..

Here is some code to reproduce the issue:

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.backend as keras_backend


def copy_model(model):
    copied_model = keras.Sequential()
    copied_model.add(keras.layers.Dense(5, input_shape=(1,)))
    copied_model.add(keras.layers.Dense(1))
    copied_model.set_weights(model.get_weights())
    return copied_model


def compute_loss(model, x, y):
    logits = model(x)  # prediction of my model
    mse = keras_backend.mean(keras.losses.mean_squared_error(y, logits))  # loss between prediction and label/truth
    return mse, logits


# meta_model to learn in outer gradient tape
meta_model = keras.Sequential()
meta_model.add(keras.layers.Dense(5, input_shape=(1,)))
meta_model.add(keras.layers.Dense(1))

# optimizer for training
optimizer = keras.optimizers.Adam()


# function to compute model_copy's adapted params
def do_calc(x, y, meta_model):
    with tf.GradientTape() as gg:
        model_copy = copy_model(meta_model)
        gg.watch(x)
        gg.watch(meta_model.trainable_variables)
        gg.watch(model_copy.trainable_variables)
        loss, _ = compute_loss(model_copy, x, y)
    gradient = gg.gradient(loss, model_copy.trainable_variables)
    optimizer.apply_gradients(zip(gradient, model_copy.trainable_variables))
    return model_copy


# inputs for training
x = tf.constant(3.0, shape=(1, 1, 1))
y = tf.constant(3.0, shape=(1, 1, 1))

with tf.GradientTape() as g:

    g.watch(x)
    g.watch(y)

    model_copy = do_calc(x, y, meta_model)
    g.watch(model_copy.trainable_variables)
    # calculate loss of model_copy
    test_loss, _ = compute_loss(model_copy, x, y)

# build gradients for meta_model update
gradients_meta = g.gradient(test_loss, meta_model.trainable_variables)
# gradients are always None!?
optimizer.apply_gradients(zip(gradients_meta, meta_model.trainable_variables))

Thanks in advance for your help.

Best Answer

I found a solution: I needed to "connect" the meta_model and the model_copy somehow.

Can anybody explain why this works, and how I could achieve the same thing with a "proper" optimizer?

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.backend as keras_backend


def copy_model(model):
    copied_model = keras.Sequential()
    copied_model.add(keras.layers.Dense(5, input_shape=(1,)))
    copied_model.add(keras.layers.Dense(1))
    copied_model.set_weights(model.get_weights())
    return copied_model


def compute_loss(model, x, y):
    logits = model(x)  # prediction of my model
    mse = keras_backend.mean(keras.losses.mean_squared_error(y, logits))  # loss between prediction and label/truth
    return mse, logits


# meta_model to learn in outer gradient tape
meta_model = keras.Sequential()
meta_model.add(keras.layers.Dense(5, input_shape=(1,)))
meta_model.add(keras.layers.Dense(1))

# optimizer for training
optimizer = keras.optimizers.Adam()


# function to compute model_copy's adapted params
def do_calc(meta_model, x, y, gg, alpha=0.01):
    model_copy = copy_model(meta_model)
    loss, _ = compute_loss(model_copy, x, y)
    gradients = gg.gradient(loss, model_copy.trainable_variables)
    k = 0
    for layer in range(len(model_copy.layers)):
        # calculate adapted parameters w/ gradient descent
        # \theta_i' = \theta - \alpha * gradients
        model_copy.layers[layer].kernel = tf.subtract(meta_model.layers[layer].kernel,
                                                      tf.multiply(alpha, gradients[k]))
        model_copy.layers[layer].bias = tf.subtract(meta_model.layers[layer].bias,
                                                    tf.multiply(alpha, gradients[k + 1]))
        k += 2
    return model_copy


with tf.GradientTape() as g:
    # inputs for training
    x = tf.constant(3.0, shape=(1, 1, 1))
    y = tf.constant(3.0, shape=(1, 1, 1))
    adapted_models = []

    # model_copy = meta_model
    with tf.GradientTape() as gg:
        model_copy = do_calc(meta_model, x, y, gg)

    # calculate loss of model_copy
    test_loss, _ = compute_loss(model_copy, x, y)

# build gradients for meta_model update
gradients_meta = g.gradient(test_loss, meta_model.trainable_variables)
# gradients work. Why???
optimizer.apply_gradients(zip(gradients_meta, meta_model.trainable_variables))
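
A plausible explanation of why this works: optimizer.apply_gradients updates variables in place through assign operations, which the tape does not differentiate through, so in the first version model_copy's weights have no recorded connection to meta_model and g.gradient returns None. The second version instead overwrites model_copy's kernel and bias attributes with tensors computed from meta_model's weights, so the outer forward pass is, on the tape, a differentiable function of meta_model's variables. Below is a minimal sketch (not the author's code) of the same inner/outer step in a purely functional style, assuming the same two-layer Dense architecture; the forward helper, alpha, x_flat and y_flat are illustrative names, not part of the original code.

def forward(weights, x):
    # hypothetical helper: hand-written forward pass of the 2-layer net,
    # weights = [W0, b0, W1, b1] in the order of model.trainable_variables
    h = tf.matmul(x, weights[0]) + weights[1]
    return tf.matmul(h, weights[2]) + weights[3]

alpha = 0.01
x_flat = tf.reshape(x, (1, 1))  # Dense expects (batch, features)
y_flat = tf.reshape(y, (1, 1))

with tf.GradientTape() as outer_tape:
    with tf.GradientTape() as inner_tape:
        inner_loss = keras_backend.mean(
            keras.losses.mean_squared_error(y_flat, meta_model(x_flat)))
    inner_grads = inner_tape.gradient(inner_loss, meta_model.trainable_variables)
    # inner update as plain tensor arithmetic: theta' = theta - alpha * grad,
    # so it stays recorded on the outer tape
    adapted = [w - alpha * grad
               for w, grad in zip(meta_model.trainable_variables, inner_grads)]
    outer_loss = keras_backend.mean(
        keras.losses.mean_squared_error(y_flat, forward(adapted, x_flat)))

# differentiating through the inner update yields non-None meta gradients
gradients_meta = outer_tape.gradient(outer_loss, meta_model.trainable_variables)
optimizer.apply_gradients(zip(gradients_meta, meta_model.trainable_variables))

In this form a "proper" Keras optimizer can still drive the outer (meta) update, since only the inner step has to stay differentiable; the inner step itself cannot go through optimizer.apply_gradients, because variable assignments are not recorded on the tape.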

Regarding python - Nested gradient tapes in a function (TF 2.0), we found a similar question on Stack Overflow: https://stackoverflow.com/questions/58856784/
