gpt4 book ai didi

python - 带参数的自定义激活

转载 作者:太空宇宙 更新时间:2023-11-03 13:57:56 24 4
gpt4 key购买 nike

我正在尝试在 Keras 中创建一个激活函数,它可以接收参数 beta,如下所示:

from keras import backend as K
from keras.utils.generic_utils import get_custom_objects
from keras.layers import Activation

class Swish(Activation):

def __init__(self, activation, beta, **kwargs):
super(Swish, self).__init__(activation, **kwargs)
self.__name__ = 'swish'
self.beta = beta


def swish(x):
return (K.sigmoid(beta*x) * x)

get_custom_objects().update({'swish': Swish(swish, beta=1.)})

它在没有 beta 参数的情况下运行良好,但我如何才能在激活定义中包含该参数?我还希望在执行 model.to_json() 时保存此值,例如 ELU 激活。


更新:我根据@today的回答写了下面的代码:

from keras.layers import Layer
from keras import backend as K

class Swish(Layer):
def __init__(self, beta, **kwargs):
super(Swish, self).__init__(**kwargs)
self.beta = K.cast_to_floatx(beta)
self.__name__ = 'swish'

def call(self, inputs):
return K.sigmoid(self.beta * inputs) * inputs

def get_config(self):
config = {'beta': float(self.beta)}
base_config = super(Swish, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

def compute_output_shape(self, input_shape):
return input_shape

from keras.utils.generic_utils import get_custom_objects
get_custom_objects().update({'swish': Swish(beta=1.)})
gnn = keras.models.load_model("Model.h5")
arch = gnn.to_json()
with open(directory + 'architecture.json', 'w') as arch_file:
arch_file.write(arch)

但是,它目前不在 .json 文件中保存 beta 值。如何让它保值?

最佳答案

既然序列化模型的时候要保存激活函数的参数,我觉得还是把激活函数定义成一个层比较好,比如advanced activations which have been defined in Keras .你可以这样做:

from keras.layers import Layer
from keras import backend as K

class Swish(Layer):
def __init__(self, beta, **kwargs):
super(Swish, self).__init__(**kwargs)
self.beta = K.cast_to_floatx(beta)

def call(self, inputs):
return K.sigmoid(self.beta * inputs) * inputs

def get_config(self):
config = {'beta': float(self.beta)}
base_config = super(Swish, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

def compute_output_shape(self, input_shape):
return input_shape

然后您可以像使用 Keras 层一样使用它:

# ...
model.add(Swish(beta=0.3))

由于 get_config() 方法已在其定义中实现,因此在使用 to_json()保存()

关于python - 带参数的自定义激活,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53050448/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com