gpt4 book ai didi

tensorflow - 混合精度训练导致 NaN 损失

转载 作者:行者123 更新时间:2023-12-05 06:01:34 29 4
gpt4 key购买 nike

我一直在关注 Mixed Precision Guide .因此,我正在设置:

keras.mixed_precision.set_global_policy(mixed_precision)

像这样包装优化器:

if mixed_precision.startswith('mixed'):
logger.info(f'Using LossScaleOptimizer for mixed-precision policy "{mixed_precision}"')
optimizer = keras.mixed_precision.LossScaleOptimizer(optimizer)

我的模型有简单的 Dense 层作为输出,我将其设置为“float32”

# Set dtype explicitly in last layer for mixed-precision training (float32 for numeric stability).
self.output_dense = layers.Dense(vocab_size, dtype=tf.float32)

和自定义的 train_step() 实现,我对此进行了修改:

with tf.GradientTape() as tape:
model_loss = self.loss_fn(
inputs,
y_true=y_true,
mask=mask
)

is_mixed_precision = isinstance(self.optimizer, mixed_precision.LossScaleOptimizer)

# We always want to return the unmodified model_loss for Tensorboard
if is_mixed_precision:
loss = self.optimizer.get_scaled_loss(model_loss)
else:
loss = model_loss

gradients = tape.gradient(loss, self.trainable_variables)

if is_mixed_precision:
gradients = self.optimizer.get_unscaled_gradients(gradients)

return model_loss, gradients

然而,一段时间后我的损失仍然变成了NaN:

enter image description here

在外面我正在确认该策略是否已被模型识别:

logger.info(f'Mixed-precision policy: {mixed_precision}')
logger.info(f'Compute dtype: {model.compute_dtype}')
logger.info(f'Variable dtype: {model.variable_dtype}')
keras.py:216] Mixed-precision policy: mixed_float16
keras.py:217] Compute dtype: float16
keras.py:218] Variable dtype: float32

但我可以看出这是由于 NaN 损失..

有什么明显我做错了或遗漏了什么吗?知道如何在此处追踪问题吗?

最佳答案

经过一番反射(reflection),我想我找到了问题所在。它位于我自定义的多头注意力层中。更具体地说,问题似乎出在我使用 value.dtype.min 以便将掩码应用于 logits 的地方,例如:

logits += value.dtype.min * (1.0 - mask)

有趣的是,这甚至在一开始就奏效了。考虑一下,您很有可能会从一开始就遇到下溢,但我能够训练模型一段时间,直到出现 NaN

无论如何,我的解决方案是给它一些空间,所以我简单地将 dtype 的最小值除以二:

logits += (value.dtype.min / 2.0) * (1.0 - mask)

关于tensorflow - 混合精度训练导致 NaN 损失,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67159157/

29 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com