
python - Memory leak in custom TensorFlow training with @tf.function


I'm trying to write my own training loop for TF2/Keras, following the official Keras walkthrough. The vanilla version works like a charm, but when I add the @tf.function decorator to my training step, some memory leak eats up all of my memory and I lose control of my machine. Does anyone know what's going on?
The relevant parts of the code look like this:

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = siamese_network(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, siamese_network.trainable_weights)
    optimizer.apply_gradients(zip(grads, siamese_network.trainable_weights))
    train_acc_metric.update_state(y, logits)
    return loss_value

@tf.function
def test_step(x, y):
    val_logits = siamese_network(x, training=False)
    val_acc_metric.update_state(y, val_logits)
    val_prec_metric.update_state(y, val_logits)
    val_rec_metric.update_state(y, val_logits)


for epoch in range(epochs):
    step_time = 0
    epoch_time = time.time()
    print("Start of {} epoch".format(epoch))
    for step, (x_batch_train, y_batch_train) in enumerate(train_ds):
        if step > steps_epoch:
            break

        loss_value = train_step(x_batch_train, y_batch_train)

    train_acc = train_acc_metric.result()
    train_acc_metric.reset_states()

    for val_step, (x_batch_val, y_batch_val) in enumerate(test_ds):
        if val_step > validation_steps:
            break
        test_step(x_batch_val, y_batch_val)

    val_acc = val_acc_metric.result()
    val_prec = val_prec_metric.result()
    val_rec = val_rec_metric.result()

    val_acc_metric.reset_states()
    val_prec_metric.reset_states()
    val_rec_metric.reset_states()
If I comment out the @tf.function lines, no memory leak occurs, but each step takes about 3 times longer. My guess is that the graph is somehow being created again every epoch or something like that, but I can't figure out how to fix it.
Here is the tutorial I'm following: https://keras.io/guides/writing_a_training_loop_from_scratch/
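
One way to check whether the graphs really are being rebuilt is to count traces after a few epochs (a diagnostic sketch, assuming a recent TF2 release that exposes experimental_get_tracing_count on tf.function-wrapped callables):

# If these counts keep growing epoch after epoch, the functions are
# being retraced, which matches the memory-leak symptom.
print("train_step traces:", train_step.experimental_get_tracing_count())
print("test_step traces:", test_step.experimental_get_tracing_count())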

Best answer

TL;DR:
TensorFlow may generate a new graph for every unique combination of argument values passed to a decorated function. Make sure you are passing consistently shaped Tensor objects to test_step and train_step, rather than Python objects.
Details
This is a stab in the dark. While I've never experimented with @tf.function myself, I did find the following warnings in the documentation:

tf.function also treats any pure Python value as opaque objects, and builds a separate graph for each set of Python arguments that it encounters.



Caution: Passing python scalars or lists as arguments to tf.function will always build a new graph. To avoid this, pass numeric arguments as Tensors whenever possible
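
The effect is easy to reproduce in isolation (a minimal sketch; the print is a Python side effect, so it fires only while a new graph is being traced):

import tensorflow as tf

@tf.function
def square(x):
    print("Tracing!")  # runs only during tracing, not on cached calls
    return x * x

square(tf.constant(2.0))  # prints "Tracing!": first graph is built
square(tf.constant(3.0))  # silent: same shape/dtype, cached graph is reused
square(2)                 # prints "Tracing!": a Python scalar forces a new graph
square(3)                 # prints "Tracing!": and again for every new value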


And finally:

A Function determines whether to reuse a traced ConcreteFunction by computing a cache key from an input's args and kwargs. A cache key is a key that identifies a ConcreteFunction based on the input args and kwargs of the Function call, according to the following rules (which may change):

  • The key generated for a tf.Tensor is its shape and dtype.
  • The key generated for a tf.Variable is a unique variable id.
  • The key generated for a Python primitive (like int, float, str) is its value.
  • The key generated for nested dicts, lists, tuples, namedtuples, and attrs is the flattened tuple of leaf-keys (see nest.flatten). (As a result of this flattening, calling a concrete function with a different nesting structure than the one used during tracing will result in a TypeError).
  • For all other Python types the key is unique to the object. This way a function or method is traced independently for each instance it is called with.
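
Given those rules, one defensive option is to pin the traced signature with input_signature, so that a mismatched call raises an error instead of silently building yet another graph (a sketch with placeholder shapes and a stand-in body; substitute whatever your siamese_network actually expects):

import tensorflow as tf

@tf.function(input_signature=[
    tf.TensorSpec(shape=(None, 128), dtype=tf.float32),  # x: placeholder shape
    tf.TensorSpec(shape=(None,), dtype=tf.int32),        # y: placeholder dtype
])
def train_step(x, y):
    # Stand-in body; the point is the pinned signature above.
    return tf.reduce_mean(x) + tf.cast(tf.reduce_sum(y), tf.float32)

train_step(tf.zeros((32, 128)), tf.zeros((32,), dtype=tf.int32))  # OK: one trace
# train_step([1.0, 2.0], 3)  # would raise a ValueError rather than retrace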


My takeaway from all of this is that if you aren't passing consistently sized Tensor objects to your @tf.function-ified functions (perhaps you're using Python collections or primitives instead), you are most likely creating a new graph version of the function for every distinct argument value passed in. I'm guessing that could produce the memory-explosion behavior you're seeing. I don't know how your test_ds and train_ds objects are being created, but you may want to make sure they are created such that enumerate(blah_ds) returns tensors, as in the tutorial, or at least convert the values to tensors before passing them to your test_step and train_step functions.
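
For example, if the pipeline ever yields NumPy arrays or plain Python values, converting once per batch before the call keeps the cache key stable (a sketch reusing the names from the question):

for step, (x_batch_train, y_batch_train) in enumerate(train_ds):
    if step > steps_epoch:
        break
    # Ensure stable Tensor inputs: the same dtype and shape pattern on
    # every call means the same cached graph is reused.
    x_batch_train = tf.convert_to_tensor(x_batch_train)
    y_batch_train = tf.convert_to_tensor(y_batch_train)
    loss_value = train_step(x_batch_train, y_batch_train)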

About python - Memory leak in custom TensorFlow training with @tf.function: we found a similar question on Stack Overflow: https://stackoverflow.com/questions/67116476/
