gpt4 book ai didi

Keras:具有类权重的 LSTM

转载 作者:行者123 更新时间:2023-12-01 09:52:55 24 4
gpt4 key购买 nike

我的问题与 this question 密切相关但也超越了它。

我正在尝试实现以下 LSTM在 Keras 那里

  • 时间步数为 nb_tsteps=10
  • 输入特征的数量是 nb_feat=40
  • 每个时间步长 LSTM 单元的数量为 120
  • LSTM 层之后是 TimeDistributedDense 层

  • 从上面引用的问题中,我明白我必须将输入数据呈现为
    nb_samples, 10, 40
    我在哪里得到 nb_samples通过滚动一个长度的窗口 nb_tsteps=10跨越形状的原始时间序列 (5932720, 40) .代码因此
    model = Sequential()
    model.add(LSTM(120, input_shape=(X_train.shape[1], X_train.shape[2]),
    return_sequences=True, consume_less='gpu'))
    model.add(TimeDistributed(Dense(50, activation='relu')))
    model.add(Dropout(0.2))
    model.add(TimeDistributed(Dense(20, activation='relu')))
    model.add(Dropout(0.2))
    model.add(TimeDistributed(Dense(10, activation='relu')))
    model.add(Dropout(0.2))
    model.add(TimeDistributed(Dense(3, activation='relu')))
    model.add(TimeDistributed(Dense(1, activation='sigmoid')))

    现在我的问题(假设以上是正确的):
    二元响应 (0/1) 严重不平衡,我需要传递 class_weight字典如 cw = {0: 1, 1: 25}model.fit() .但是我得到一个异常(exception) class_weight not supported for 3+ dimensional targets .这是因为我将响应数据显示为 (nb_samples, 1, 1) .如果我将其 reshape 为二维数组 (nb_samples, 1)我收到异常 Error when checking model target: expected timedistributed_5 to have 3 dimensions, but got array with shape (5932720, 1) .

    非常感谢您的帮助!

    最佳答案

    我认为你应该使用 sample_weightsample_weight_mode='temporal' .

    来自 Keras 文档:

    sample_weight: Numpy array of weights for the training samples, used for scaling the loss function (during training only). You can either pass a flat (1D) Numpy array with the same length as the input samples (1:1 mapping between weights and samples), or in the case of temporal data, you can pass a 2D array with shape (samples, sequence_length), to apply a different weight to every timestep of every sample. In this case you should make sure to specify sample_weight_mode="temporal" in compile().



    在您的情况下,您需要提供与标签形状相同的二维数组。

    关于Keras:具有类权重的 LSTM,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38891390/

    24 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com