作者热门文章
- html - 出于某种原因,IE8 对我的 Sass 文件中继承的 html5 CSS 不友好?
- JMeter 在响应断言中使用 span 标签的问题
- html - 在 :hover and :active? 上具有不同效果的 CSS 动画
- html - 相对于居中的 html 内容固定的 CSS 重复背景?
我最近阅读 this paper它引入了一个称为“预热”(WU)的过程,该过程包括将 KL 散度中的损失乘以一个变量,该变量的值取决于 epoch 的数量(它从 0 到 1 线性演化)
我想知道这是否是做到这一点的好方法:
beta = K.variable(value=0.0)
def vae_loss(x, x_decoded_mean):
# cross entropy
xent_loss = K.mean(objectives.categorical_crossentropy(x, x_decoded_mean))
# kl divergence
for k in range(n_sample):
epsilon = K.random_normal(shape=(batch_size, latent_dim), mean=0.,
std=1.0) # used for every z_i sampling
# Sample several layers of latent variables
for mean, var in zip(means, variances):
z_ = mean + K.exp(K.log(var) / 2) * epsilon
# build z
try:
z = tf.concat([z, z_], -1)
except NameError:
z = z_
except TypeError:
z = z_
# sum loss (using a MC approximation)
try:
loss += K.sum(log_normal2(z_, mean, K.log(var)), -1)
except NameError:
loss = K.sum(log_normal2(z_, mean, K.log(var)), -1)
print("z", z)
loss -= K.sum(log_stdnormal(z) , -1)
z = None
kl_loss = loss / n_sample
print('kl loss:', kl_loss)
# result
result = beta*kl_loss + xent_loss
return result
# define callback to change the value of beta at each epoch
def warmup(epoch):
value = (epoch/10.0) * (epoch <= 10.0) + 1.0 * (epoch > 10.0)
print("beta:", value)
beta = K.variable(value=value)
from keras.callbacks import LambdaCallback
wu_cb = LambdaCallback(on_epoch_end=lambda epoch, log: warmup(epoch))
# train model
vae.fit(
padded_X_train[:last_train,:,:],
padded_X_train[:last_train,:,:],
batch_size=batch_size,
nb_epoch=nb_epoch,
verbose=0,
callbacks=[tb, wu_cb],
validation_data=(padded_X_test[:last_test,:,:], padded_X_test[:last_test,:,:])
)
最佳答案
这是行不通的。我对其进行了测试,以找出它无法正常工作的确切原因。要记住的关键是 Keras 在训练开始时创建了一次静态图。
因此,vae_loss
函数仅被调用一次以创建损失张量,这意味着对 beta
的引用每次计算损失时,变量将保持不变。但是,您的 warmup
函数将 beta 重新分配给新的 K.variable
.因此,beta
用于计算损失的是不同的 beta
比更新的那个,并且该值将始终为 0。
这是一个简单的修复。只需在您的 warmup
中更改这一行打回来:beta = K.variable(value=value)
到:K.set_value(beta, value)
这样实际值在beta
“就地”更新而不是创建新变量,并且将正确重新计算损失。
关于deep-learning - 变分自编码器 : implementing warm-up in Keras,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42787181/
COW 不是奶牛,是 Copy-On-Write 的缩写,这是一种是复制但也不完全是复制的技术。 一般来说复制就是创建出完全相同的两份,两份是独立的: 但是,有的时候复制这件事没多大必要
我是一名优秀的程序员,十分优秀!