gpt4 book ai didi

python - 如何在 Keras 中正确实现自定义事件正则化器?

转载 作者:太空狗 更新时间:2023-10-29 21:11:56 25 4
gpt4 key购买 nike

我正在尝试根据 Andrew Ng 的讲义实现稀疏自动编码器,如图所示 here .它要求通过引入惩罚项(K-L 散度)在自动编码器层上应用稀疏约束。我尝试使用提供的方向来实现这个 here ,经过一些小的改动。下面是 SparseActivityRegularizer 类实现的 K-L 散度和稀疏惩罚项,如下所示。

def kl_divergence(p, p_hat):
return (p * K.log(p / p_hat)) + ((1-p) * K.log((1-p) / (1-p_hat)))

class SparseActivityRegularizer(Regularizer):
sparsityBeta = None

def __init__(self, l1=0., l2=0., p=-0.9, sparsityBeta=0.1):
self.p = p
self.sparsityBeta = sparsityBeta

def set_layer(self, layer):
self.layer = layer

def __call__(self, loss):
#p_hat needs to be the average activation of the units in the hidden layer.
p_hat = T.sum(T.mean(self.layer.get_output(True) , axis=0))

loss += self.sparsityBeta * kl_divergence(self.p, p_hat)
return loss

def get_config(self):
return {"name": self.__class__.__name__,
"p": self.l1}

模型是这样构建的

X_train = np.load('X_train.npy')
X_test = np.load('X_test.npy')

autoencoder = Sequential()
encoder = containers.Sequential([Dense(250, input_dim=576, init='glorot_uniform', activation='tanh',
activity_regularizer=SparseActivityRegularizer(p=-0.9, sparsityBeta=0.1))])

decoder = containers.Sequential([Dense(576, input_dim=250)])
autoencoder.add(AutoEncoder(encoder=encoder, decoder=decoder, output_reconstruction=True))
autoencoder.layers[0].build()
autoencoder.compile(loss='mse', optimizer=SGD(lr=0.001, momentum=0.9, nesterov=True))
loss = autoencoder.fit(X_train_tmp, X_train_tmp, nb_epoch=200, batch_size=800, verbose=True, show_accuracy=True, validation_split = 0.3)
autoencoder.save_weights('SparseAutoEncoder.h5',overwrite = True)
result = autoencoder.predict(X_test)

当我调用 fit() 函数时,我得到负损失值并且输出与输入完全不同。我想知道我哪里出错了。计算层的平均激活并使用此自定义稀疏正则化器的正确方法是什么?任何形式的帮助将不胜感激。谢谢!

我将 Keras 0.3.1 与 Python 2.7 结合使用,因为最新的 Keras (1.0.1) 版本没有自动编码器层。

最佳答案

您已经定义了 self.p = -0.9 而不是原始海报和您提到的讲义都使用的 0.05 值。

关于python - 如何在 Keras 中正确实现自定义事件正则化器?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36913281/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com