gpt4 book ai didi

python - 具有不同时间步长的 RNN 的 Keras 掩蔽

转载 作者:IT老高 更新时间:2023-10-28 20:55:43 27 4
gpt4 key购买 nike

我正在尝试使用具有不同时间长度的序列在 Keras 中拟合 RNN。我的数据位于格式为 (sample, time, feature) = (20631, max_time, 24) 的 Numpy 数组中,其中 max_time 在运行时确定为时间戳最多的样本可用的时间步长。我已经用 0 填充了每个时间序列的开头,显然最长的除外。

我最初是这样定义我的模型的......

model = Sequential()
model.add(Masking(mask_value=0., input_shape=(max_time, 24)))
model.add(LSTM(100, input_dim=24))
model.add(Dense(2))
model.add(Activation(activate))
model.compile(loss=weibull_loglik_discrete, optimizer=RMSprop(lr=.01))
model.fit(train_x, train_y, nb_epoch=100, batch_size=1000, verbose=2, validation_data=(test_x, test_y))

为了完整起见,这里是损失函数的代码:

def weibull_loglik_discrete(y_true, ab_pred, name=None):
y_ = y_true[:, 0]
u_ = y_true[:, 1]
a_ = ab_pred[:, 0]
b_ = ab_pred[:, 1]

hazard0 = k.pow((y_ + 1e-35) / a_, b_)
hazard1 = k.pow((y_ + 1) / a_, b_)

return -1 * k.mean(u_ * k.log(k.exp(hazard1 - hazard0) - 1.0) - hazard1)

这是自定义激活函数的代码:

def activate(ab):
a = k.exp(ab[:, 0])
b = k.softplus(ab[:, 1])

a = k.reshape(a, (k.shape(a)[0], 1))
b = k.reshape(b, (k.shape(b)[0], 1))

return k.concatenate((a, b), axis=1)

当我拟合模型并做出一些测试预测时,测试集中的每个样本都会得到完全相同的预测,这似乎很可疑。

如果我移除 mask 层,情况会好转,这让我觉得 mask 层有问题,但据我所知,我完全按照文档操作。

掩蔽层是否有错误指定?我还缺少其他东西吗?

最佳答案

您实现屏蔽的方式应该是正确的。如果您有形状为 (samples, timesteps, features) 的数据,并且您想使用与 features 参数大小相同的零掩码来掩盖缺少数据的时间步,则添加 Masking (mask_value=0., input_shape=(timesteps, features)).见这里:keras.io/layers/core/#masking

您的模型可能过于简单,和/或您的时期数可能不足以使模型区分所有类。试试这个模型:

model = Sequential()
model.add(Masking(mask_value=0., input_shape=(max_time, 24)))
model.add(LSTM(256, input_dim=24))
model.add(Dense(1024))
model.add(Dense(2))
model.add(Activation(activate))
model.compile(loss=weibull_loglik_discrete, optimizer=RMSprop(lr=.01))
model.fit(train_x, train_y, nb_epoch=100, batch_size=1000, verbose=2, validation_data=(test_x, test_y))

如果这不起作用,请尝试将 epoch 加倍几次(例如 200、400),看看这是否会改善结果。

关于python - 具有不同时间步长的 RNN 的 Keras 掩蔽,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42353056/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com