gpt4 book ai didi

python - Keras SimpleRNN/LSTM 默认使用哪个轴作为时间轴?

转载 作者:行者123 更新时间:2023-12-01 13:10:14 26 4
gpt4 key购买 nike

当使用 SimpleRNNLSTM 进行经典处理时 sentiment analysis算法(此处适用于长度 <= 250 个单词/标记的句子):

model = Sequential()
model.add(Embedding(5000, 32, input_length=250)) # Output shape: (None, 250, 32)
model.add(SimpleRNN(100)) # Output shape: (None, 100)
model.add(Dense(1, activation='sigmoid')) # Output shape: (None, 1)

哪里指定了RNN输入的哪个轴作为“时间”轴?

更准确地说,在 Embedding 层之后,给定的输入句子,例如“the cat sat on the mat”,被编码为形状为 (250, 32) 的矩阵 x,其中 250 是输入的最大长度(以字为单位)文本和 32 嵌入的维度。然后,在 Keras 的哪个位置指定是否使用它:

  1. h[t] = 激活(W_h * x[:, t] + U_h * h[t-1] + b_h )

或者这个:

  1. h[t] = 激活(W_h * x[t, :] + U_h * h[t-1] + b_h)

(在这两种情况下,y[t] = activation( W_y * h[t] + b_y ))

TL;DR:如果 RNN Keras 层的输入大小为 (250, 32),默认情况下它使用哪个轴作为时间轴? Keras 或 Tensorflow 文档中的何处对此进行了详细说明?


PS:如何解释参数个数(由model.summary()给出)是13300? W_h 有 100x32 系数,U_h 有 100x100 系数,b_h 有 100x1 系数,即我们已经有 13300! W_yb_y 没有系数了!怎么解释呢?

最佳答案

时间轴:一直为dim 1,除非time_major=True,否则为dim 2; Embedding 层输出一个 3D 张量。这个可以看here其中 step_input_shaperecurrent loop 中每一步输入到 RNN cell 的形状.对于您的情况,timesteps=250,并且 SimpleRNN 单元在每一步“看到”一个形状为 (batch_size, 32) 的张量。


# of params:您可以通过检查每一层的 .build() 代码来了解图形的派生方式:Embedding , SimpleRNN , Dense ,或者同样在每一层上调用 .weights。对于您的情况,w/l = model.layers[1]:

  • l.weights[0].shape == (32, 100) --> 3200 个参数(内核)
  • l.weights[1].shape == (100, 100) --> 10000 个参数 (recurrent_kernel)
  • l.weights[2].shape == (100,) --> 100 个参数 (bias) (sum: 13,300 )

计算逻辑:没有W_yb_y; “y”是隐藏状态,h,实际上适用于所有循环层 - 您引用的内容可能来自通用 RNN 公式。 @“在这两种情况下……”——这是错误的;要查看实际发生的情况,请检查 .call()代码。

附言我建议定义模型的完整 batch_shape 以进行调试,因为它消除了模棱两可的 None 形状


SimpleRNN 公式与代码:根据要求;请注意源代码中的 h 具有误导性,在公式(“预激活”)中通常为 z

enter image description here

  • return_sequences=True -> 返回所有时间步的输出:(batch_size, timesteps, channels)

  • return_sequences=False -> 仅返回最后时间步的输出:(batch_size, 1, channels)。参见 here

关于python - Keras SimpleRNN/LSTM 默认使用哪个轴作为时间轴?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60571934/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com