gpt4 book ai didi

python - Tensorflow 1.14+ 序列化子类 Keras 层?

转载 作者:行者123 更新时间:2023-11-30 09:41:45 25 4
gpt4 key购买 nike

我已经仔细阅读并重新阅读了 Tensorflow Keras 文档,例如

我有一个简单的子类化层:

class SimpleLayer(tf.keras.layers.Layer):
def __init__(self, filters, kernel_size, **kwargs):
super(SimpleLayer, self).__init__()
self.filters = filters
self.kernel_size = kernel_size
self.c1 = tf.keras.layers.Conv1D(filters, kernel_size, padding='same', activation='relu')
self.c2 = tf.keras.layers.Conv1D(filters, kernel_size, padding='same')

def call(self, inputs):
x = inputs
x = self.c1(x)
x = self.c2(x)
return x

def get_config(self):
# config = super(tf.keras.layers.Layer, self).get_config()
config = {}
config.update({
'filters': self.filters,
'kernel_size': self.kernel_size,
})
return config

然后有一个功能模型:


x = tf.keras.Inputs(...)

# some keras layers
y = tf.keras.layers... (x)

# my keras layer
y = SimpleLayer(...)(y)

# some keras layers
y = tf.keras.layers... (y)
y = tf.keras.layers.Dense(1)(y)

model = tf.keras.Model(inputs=x, outputs=y)
model.compile(...)

model.fit(...)

model.save('model.h5')

然后我可以将模型加载为:

tf.keras.models.load_model('model.h5')

但我得到:

ValueError: Unknown layer: SimpleLayer

来自 docs :

If you need your custom layers to be serializable as part of a Functional model, you can optionally implement a get_config method

我有。

我做错了什么?

最佳答案

您需要在加载过程中告诉 keras 您的自定义层,您可以使用 custom_objects 参数来执行此操作:

tf.keras.models.load_model('model.h5', custom_objects = {'SimpleLayer': SimpleLayer})

关于python - Tensorflow 1.14+ 序列化子类 Keras 层?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57792660/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com