gpt4 book ai didi

python - Keras:序列化掩蔽层以进行保存/加载

转载 作者:太空宇宙 更新时间:2023-11-04 04:50:45 24 4
gpt4 key购买 nike

所以我在 Keras 中有一个自定义层,其中使用了一个 mask 。

为了让它与保存/加载一起工作,我需要正确地序列化 Mask。所以这个标准代码不起作用:

def get_config(self):
config = {'mask': self.mask}
base_config = super(Mixing, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

其中掩码是对掩码层的引用。

我不确定如何序列化 Masking(或一般的 Keras Layers)。谁能帮忙?

最佳答案

您可以实现相同的 serializing methods作为内置的 Wrapper 类。

def get_config(self):
config = {'layer': {'class_name': self.layer.__class__.__name__,
'config': self.layer.get_config()}}
base_config = super(Wrapper, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

@classmethod
def from_config(cls, config, custom_objects=None):
from . import deserialize as deserialize_layer
layer = deserialize_layer(config.pop('layer'),
custom_objects=custom_objects)
return cls(layer, **config)

序列化时,在get_config中,将内层的类名和config保存在config['layer']中。

from_config中,内层使用config['layer']通过deserialize_layer反序列化。

关于python - Keras:序列化掩蔽层以进行保存/加载,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48391265/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com