gpt4 book ai didi

python - 如何使用 K.clear_session() 修复 Keras 中的内存泄漏问题?

转载 作者:行者123 更新时间:2023-12-05 07:03:20 24 4
gpt4 key购买 nike

我有一个网络,我通过批量输入我的数据来训练它,我正在使用 model.train_on_batch() 来做到这一点。如果我只运行这个训练部分,我会看到我的网络在 40 多个时期(到目前为止)以 3% 的 RAM 利用率训练得很好,每个时期有大约 2000 次迭代。当我尝试在每个纪元之后进行验证时(这也是分批发生的),内存泄漏非常严重,导致 90% 的 RAM 使用率和我的代码被杀死。所以我在过去几天尝试了一些事情,循环中的 model.predict() 似乎导致内存泄漏 open issue at Tensorflow GitHub .我尝试了 predict_on_batch(),同样的行为。 model(inputs, training=False) 似乎减缓了内存泄漏,而不是从 3% - 7% - 13% - 40% - 80% - 90%(60 秒的间隔)突然跳跃, 它以每分钟 1% 的速度增加。但在某个时候它也达到了 90%。我唯一要尝试使用这个 github 线程的是使用 K.clear_session()

我试着阅读了 K.clear_session() 的文档和一些 SO 帖子,都建议在创建多个模型时使用它,但我没有这样做。所以我的问题是,如果我有一个模型在循环中进行训练和评估,我应该在哪里使用 K.clear_session(),在每个纪元之后并在每个纪元之前重新加载保存的模型?这是正确的吗?

除此之外,我还得到拓扑排序错误 another open issue ,所以我想知道是不是因为我在循环训练,因为我的代码没有循环,这也会以某种方式导致内存泄漏,而 K.clear_session() 会有所帮助吗?

我的代码结构的最小示例:

from tensorflow.keras.models import Model
K = tf.keras.backend

def myModel():
**some architecture**

ip = Input(shape=(h, w, 3))
op = myModel(ip)
model = Model(ip, op)
model.compile(optimizer=Adam(lr=1e-6), loss=custom_mean_squared_error)

for e in range(numEpochs):
for batch in range(0, num_train_batches):
x = readImages()
y = readLabels()
loss = model.train_on_batch(x, y)


for batch in range(0, num_val_batches):
x = readImages()
y = model.predict(x)
val_loss = K.get_value(custom_mean_squared_error(x,y))
# save predictions

# plot training vs validation loss

Tensorflow-gpu-1.14,python3.6。如果我也做错了什么,将不胜感激。

最佳答案

这似乎对我有用,但会使过程变慢:

from tensorflow.keras.models import Model
K = tf.keras.backend

def myModel():
**some architecture**

ip = Input(shape=(h, w, 3))
op = myModel(ip)
model = Model(ip, op)
model.compile(optimizer=Adam(lr=1e-6), loss=custom_mse)

for e in range(numEpochs):
for batch in range(0, num_train_batches):
x = readImages()
y = readLabels()
# define appropriate flags for first loop
model = tf.keras.models.load_model(model_path,custom_objects={ 'custom_mse': custom_mse } )
loss = model.train_on_batch(x, y)

model.save(model_path)

for batch in range(0, num_val_batches):
x = readImages()
model = tf.keras.models.load_model(model_path,custom_objects={ 'custom_mse': custom_mse } )
y = model.predict(x)
K.clear_session()
val_loss = K.get_value(custom_mean_squared_error(x,y))
# save predictions

# plot training vs validation loss

关于python - 如何使用 K.clear_session() 修复 Keras 中的内存泄漏问题?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63258249/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com