
python - Memory leak with Keras Lambda layers


I need to split the channels of a tensor in order to apply a different normalization to each split. For this I used Keras Lambda layers:

# split the channels in two (first part for IN, second for BN)
x_in = Lambda(lambda x: x[:, :, :, :split_index])(x)
x_bn = Lambda(lambda x: x[:, :, :, split_index:])(x)

# apply IN and BN on their respective group of channels
x_in = InstanceNormalization(axis=3)(x_in)
x_bn = BatchNormalization(axis=3)(x_bn)

# concatenate outputs of IN and BN
x = Concatenate(axis=3)([x_in, x_bn])

Everything works as expected (see the model.summary() below), but RAM keeps growing at every iteration, which indicates a memory leak.

Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            (None, 832, 832, 1)  0
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, 832, 832, 32) 320         input_1[0][0]
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 832, 832, 16) 0           conv1[0][0]
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, 832, 832, 16) 0           conv1[0][0]
__________________________________________________________________________________________________
instance_normalization_1 (Insta (None, 832, 832, 16) 32          lambda_1[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 832, 832, 16) 64          lambda_2[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 832, 832, 32) 0           instance_normalization_1[0][0]
                                                                 batch_normalization_1[0][0]
__________________________________________________________________________________________________

I am fairly sure the leak comes from the Lambda layers, because I tried another strategy where I do not split at all: I apply both normalizations independently on all channels and then add the feature maps together. I did not observe any memory leak with this code:

# apply IN and BN on the input tensor independently
x_in = InstanceNormalization(axis=3)(x)
x_bn = BatchNormalization(axis=3)(x)

# addition of the feature maps output by IN and BN
x = Add()([x_in, x_bn])

Any idea how to fix this memory leak? I am using Keras 2.2.4 with Tensorflow 1.15.3, and upgrading to TF 2 or tf.keras is not an option for me right now.

Best Answer

Thibault Bacqueyrisses's answer was right: the memory leak disappeared with a custom layer!

Here is my implementation:

class Crop(keras.layers.Layer):
    def __init__(self, dim, start, end, **kwargs):
        """
        Slice the tensor on the given dimension, keeping what is between
        start and end.

        Args:
            dim (int)   : dimension of the tensor (including the batch dim)
            start (int) : index of where to start the cropping
            end (int)   : index of where to stop the cropping
        """
        super(Crop, self).__init__(**kwargs)
        self.dimension = dim
        self.start = start
        self.end = end

    def call(self, inputs):
        if self.dimension == 0:
            return inputs[self.start:self.end]
        if self.dimension == 1:
            return inputs[:, self.start:self.end]
        if self.dimension == 2:
            return inputs[:, :, self.start:self.end]
        if self.dimension == 3:
            return inputs[:, :, :, self.start:self.end]
        if self.dimension == 4:
            return inputs[:, :, :, :, self.start:self.end]

    def compute_output_shape(self, input_shape):
        # note: this assumes the crop is applied to the last dimension,
        # which is the case here (channel split on axis 3)
        return input_shape[:-1] + (self.end - self.start,)

    def get_config(self):
        # expose the constructor arguments so the layer can be serialized
        config = {
            'dim': self.dimension,
            'start': self.start,
            'end': self.end,
        }
        base_config = super(Crop, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
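
For reference, here is a minimal sketch of how the custom layer replaces the original Lambda slices (split_index and n_channels are placeholders for the actual channel counts of the tensor being split):

# split the channels with the custom Crop layer instead of Lambda
x_in = Crop(dim=3, start=0, end=split_index)(x)
x_bn = Crop(dim=3, start=split_index, end=n_channels)(x)

# apply IN and BN on their respective group of channels, as before
x_in = InstanceNormalization(axis=3)(x_in)
x_bn = BatchNormalization(axis=3)(x_bn)

# concatenate outputs of IN and BN
x = Concatenate(axis=3)([x_in, x_bn])

Since get_config is implemented, a saved model can be reloaded by passing the class through custom_objects, e.g. keras.models.load_model('model.h5', custom_objects={'Crop': Crop}).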

Regarding python - Memory leak with Keras Lambda layers, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/64139658/
