
python - Knowledge distillation loss with TensorFlow 2 + Keras

Reposted · Author: 行者123 · Updated: 2023-12-04 11:41:13

I'm trying to implement a very simple Keras model that uses knowledge distillation [1] from another model.
Roughly speaking, I need to replace the original loss L(y_true, y_pred) with L(y_true, y_pred) + L(y_teacher_pred, y_pred), where y_teacher_pred is the prediction of another model.

I tried the following:

import tensorflow as tf

def create_student_model_with_distillation(teacher_model):

    inp = tf.keras.layers.Input(shape=(21,))

    model = tf.keras.models.Sequential()
    model.add(inp)

    model.add(...)
    model.add(tf.keras.layers.Dense(units=1))

    teacher_pred = teacher_model(inp)

    def my_loss(y_true, y_pred):
        loss = tf.keras.losses.mean_squared_error(y_true, y_pred)
        loss += tf.keras.losses.mean_squared_error(teacher_pred, y_pred)
        return loss

    model.compile(loss=my_loss, optimizer='adam')

    return model

However, when I try to call fit on my model, I get:

TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.

How can I fix this?

References

[1] https://arxiv.org/abs/1503.02531

Best Answer

Actually, this blog post answers your question: keras blog.
In short: use the new TF2 API and call the teacher's predict *before* the tf.GradientTape() block:

def train_step(self, data):
    # Unpack data
    x, y = data

    # Forward pass of teacher (outside the tape, so no graph tensors leak in)
    teacher_predictions = self.teacher(x, training=False)

    with tf.GradientTape() as tape:
        # Forward pass of student
        student_predictions = self.student(x, training=True)

        # Compute losses
        student_loss = self.student_loss_fn(y, student_predictions)
        distillation_loss = self.distillation_loss_fn(
            tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
            tf.nn.softmax(student_predictions / self.temperature, axis=1),
        )
        loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
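The snippet above follows the Keras distillation example, which uses softmax and a temperature because it targets classification. The question, however, is a regression setup (a single Dense(1) output with MSE). Below is a minimal end-to-end sketch of the same idea adapted to that case; the Distiller class, the toy teacher/student architectures, and the alpha weighting are illustrative assumptions, not the questioner's actual models:

```python
import numpy as np
import tensorflow as tf

class Distiller(tf.keras.Model):
    """Student whose training loss adds an MSE term against teacher predictions.

    The teacher forward pass happens *outside* tf.GradientTape, so its tensors
    enter the student's loss as plain constants and no graph tensor leaks out
    of the function-building context (the error from the question)."""

    def __init__(self, student, teacher, alpha=0.5):
        super().__init__()
        self.student = student
        self.teacher = teacher
        self.alpha = alpha  # weight between ground-truth loss and distillation loss

    def call(self, x, training=False):
        return self.student(x, training=training)

    def train_step(self, data):
        x, y = data
        # Teacher prediction outside the tape: treated as a fixed target.
        teacher_pred = self.teacher(x, training=False)
        with tf.GradientTape() as tape:
            student_pred = self.student(x, training=True)
            student_loss = tf.reduce_mean(
                tf.keras.losses.mean_squared_error(y, student_pred))
            distill_loss = tf.reduce_mean(
                tf.keras.losses.mean_squared_error(teacher_pred, student_pred))
            loss = self.alpha * student_loss + (1 - self.alpha) * distill_loss
        # Only the student's weights are updated; the teacher stays frozen.
        grads = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.student.trainable_variables))
        return {"loss": loss}

def make_net():
    # Toy architecture standing in for the real teacher/student networks.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(21,)),
        tf.keras.layers.Dense(8, activation="relu"),
        tf.keras.layers.Dense(1),
    ])

teacher = make_net()   # in practice: a pretrained model
student = make_net()

distiller = Distiller(student, teacher, alpha=0.5)
distiller.compile(optimizer="adam")

x = np.random.rand(32, 21).astype("float32")
y = np.random.rand(32, 1).astype("float32")
history = distiller.fit(x, y, epochs=1, verbose=0)  # trains without the Graph-tensor error
```

Because `teacher_pred` is computed before the tape opens, `fit` runs cleanly; the same ordering is what the blog's classification version relies on.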

On the topic of "Knowledge distillation loss with TensorFlow 2 + Keras", a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/59137907/
