gpt4 book ai didi

python - 使用(常量)参数保存/加载 Keras 模型

转载 作者:行者123 更新时间:2023-12-05 07:09:56 26 4
gpt4 key购买 nike

我的情况与 "Save/load a keras model with constants" 相似,但又略有不同。

我正在 tf.keras(TFv1.12,是的,我知道)中创建对象检测模型(基于 YOLO),其原始输出需要后处理为边界框.

这涉及到一些参数,这些参数对于模型的目的来说是不变的,但是构建模型的脚本的参数:例如类的数量,以及生成相对于框的“ anchor ”位置的矩阵。

我的模型将被加载到一个 TFServing 容器中,所以我试图确保:

  1. 转换被封装在模型中,而不是让用户去做或分离出后处理逻辑
  2. 保存的模型工件(例如 Keras h5 或 TF pb+params)足以加载和提供模型

正确的做法是什么?

据我所知,以下不起作用:

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Lambda, Layer

# Lambda layer using closure scope triggers error when trying to load the model:
# seems like `param` is defined but some weird object
def make_output_lambda(param):
def mylambda(raw_output):
return raw_output + param
return Lambda(mylambda)

# Even if the custom layer type is added to `custom_objects` on
# `tf.keras.models.load_model()` - it seems to get called without the positional
# arguments:
class MyCustomLayer(Layer):
def __init__(self, param, **kwargs):
super(YOLOHeadLayer, self).__init__(**kwargs)
self.param = param # or K.constant(param) - same overall problem

def call(self, inputs):
return inputs + self.param

# Keras throws an error when creating a `Model` that depends on a constant
# tensor which isn't an `Input` (and who wants a constant "Input"?)
def lambdatwo(inputs):
return inputs[0] + inputs[1]
param_tensor = K.constant(param)
y = Lambda(lambdatwo)((raw_output, param_tensor))

最佳答案

你应该使用 add_weight来自继承的 Layer 类的方法,使用 trainable=False 标志来避免更新常量:

class MyCustomLayer(Layer):
def __init__(self, param, **kwargs):
super(MyCustomLayer, self).__init__(**kwargs)
self.param = self.add_weight(
shape=tf.shape(param),
initializer=lambda _: param,
trainable=False
)

def call(self, inputs):
return inputs + self.param

关于python - 使用(常量)参数保存/加载 Keras 模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61400273/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com