
keras - Persistent variables in a Keras custom layer


I want to write a custom layer in which I can keep a variable in memory between runs. For example:

import numpy as np
import keras.backend as K
from keras.engine.topology import Layer
from keras.models import Sequential

class MyLayer(Layer):
    def __init__(self, out_dim=51, **kwargs):
        self.out_dim = out_dim
        super(MyLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        a = 0.0
        self.persistent_variable = K.variable(a)
        self.built = True

    def get_output_shape_for(self, input_shape):
        return (input_shape[0], 1)

    def call(self, x, mask=None):
        a = K.eval(self.persistent_variable) + 1
        K.set_value(self.persistent_variable, a)
        return self.persistent_variable

m = Sequential()
m.add(MyLayer(input_shape=(1,)))

When I run m.predict, I expect persistent_variable to be updated and the incremented value to be printed, but it always seems to print 0:

# Dummy input
x = np.zeros(1)

m.predict(x, batch_size=1)

My question is: how do I make persistent_variable increment and persist after every run of m.predict?

Thanks, Naveen

Best Answer

The trick is that you have to call self.add_update(...) in your call function to register a function that will be called every time your model is evaluated (I found this by digging into the source code of the stateful RNNs). If you set self.stateful = True, your custom update function will be called on every training and prediction call; otherwise it will only be called during training. For example:

import keras.backend as K
import numpy as np
from keras.engine.topology import Layer

class CounterLayer(Layer):
    def __init__(self, stateful=False, **kwargs):
        # True means it will increment the counter on both predict and train,
        # False means it will only increment the counter on train
        self.stateful = stateful
        super(CounterLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # Define variables in build
        self.count = K.variable(0, name="count")
        super(CounterLayer, self).build(input_shape)

    def call(self, x, mask=None):
        updates = []
        # The format is (variable, value to set it to), so this says
        # self.count = self.count + 1
        updates.append((self.count, self.count + 1))

        # You can append more updates to this list or call add_update more
        # times if you want

        # Add our custom update.
        # We pass x here so that Keras calls our update function every time
        # our layer is given a new x
        self.add_update(updates, x)

        # The layer's output is simply the current count
        return self.count

    # In newer Keras versions you may need to name this compute_output_shape instead
    def get_output_shape_for(self, input_shape):
        # We will just return our count as an array ([[count]])
        return (1, 1)

    def reset_states(self):
        # Theano-style; with the TensorFlow backend use K.set_value(self.count, 0)
        self.count.set_value(0)

Example usage:

from keras.layers import Input
from keras.models import Model
from keras.optimizers import RMSprop

inputLayer = Input(shape=(10,))
counter = CounterLayer()                # Doesn't update on predict
# counter = CounterLayer(stateful=True) # This would update each time you call predict
counterLayer = counter(inputLayer)
model = Model(input=inputLayer, output=counterLayer)
optimizer = RMSprop(lr=0.001)
model.compile(loss="mse", optimizer=optimizer)


# See the value of our counter
# (Theano-style; with the TensorFlow backend use K.get_value(counter.count))
print(counter.count.get_value())

# This won't actually train anything, but each epoch will update our counter.
# Note that if, say, you have 5 batches per epoch, the update will be called
# 5 times per epoch
model.fit(np.zeros([1, 10]), np.array([0]), batch_size=1, nb_epoch=5)

# The value of our counter has now changed
print(counter.count.get_value())

model.predict(np.zeros([1, 10]))

# If we used stateful=False this didn't change, otherwise it did
print(counter.count.get_value())
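
The snippets above target the Keras 1.x API (Model(input=..., output=...), nb_epoch, Theano-style get_value/set_value). As a rough sketch of the same counter idea on current tf.keras, assuming TensorFlow 2.x is installed, a non-trainable weight updated with assign_add behaves like the stateful=True variant, i.e. it increments on both fit and predict; the class name and shapes below are illustrative assumptions, not part of the original answer:

import numpy as np
import tensorflow as tf

class TF2CounterLayer(tf.keras.layers.Layer):
    """Hypothetical tf.keras sketch of the counter layer above."""

    def build(self, input_shape):
        # A non-trainable weight serves as the persistent counter
        self.count = self.add_weight(
            name="count", shape=(), dtype=tf.float32,
            initializer="zeros", trainable=False)
        super().build(input_shape)

    def call(self, inputs):
        # assign_add mutates the variable every time the layer runs,
        # during both fit() and predict()
        self.count.assign_add(1.0)
        # Broadcast the scalar counter to shape (batch_size, 1)
        return tf.ones_like(inputs[:, :1]) * self.count

inp = tf.keras.Input(shape=(10,))
out = TF2CounterLayer()(inp)
model = tf.keras.Model(inp, out)

print(model.predict(np.zeros((1, 10))))  # the counter advances on every call
print(model.predict(np.zeros((1, 10))))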

Regarding persistent variables in a Keras custom layer, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/41645990/
