gpt4 book ai didi

python - 为什么当作为参数传递给 `input_shape` 层时, `Dense` 不包括批处理维度?

转载 作者:行者123 更新时间:2023-12-03 09:35:49 25 4
gpt4 key购买 nike

在 Keras 中,为什么是 input_shape当作为参数传递给像 Dense 这样的层时,不包括批处理维度但在 input_shape 时包含批处理维度传递给 build模型的方法?

import tensorflow as tf
from tensorflow.keras.layers import Dense

if __name__ == "__main__":
model1 = tf.keras.Sequential([Dense(1, input_shape=[10])])
model1.summary()

model2 = tf.keras.Sequential([Dense(1)])
model2.build(input_shape=[None, 10]) # why [None, 10] and not [10]?
model2.summary()
这是 API 设计的明智选择吗?如果是,为什么?

最佳答案

您可以通过几种不同的方式指定模型的输入形状。例如,通过向模型的第一层提供以下参数之一:

  • batch_input_shape :第一个维度是批量大小的元组。
  • input_shape : 不包括批大小的元组,例如,批大小假定为 Nonebatch_size ,如果指定。
  • input_dim : 一个标量,表示输入的维度。

  • 在所有这些情况下,Keras 都是 internally storing一个属性 _batch_input_size建立模型。
    关于 build方法,我的猜测是这确实是一个有意识的选择——关于批量大小的信息可能对在某些(也许是未曾想到的)情况下构建模型有用。因此,包含批处理维度作为输入的框架 build比没有的框架更通用和完整。尽管如此,我同意你将论点命名为 batch_input_shape而不是 input_shape将使一切更加一致。

    值得一提的是,用户很少需要调用 build自己的方法。这在需要时在内部发生。如今,甚至可以 ignore input_shape创建模型时的参数(尽管像 summary 这样的方法在模型构建之前将不起作用)。在这种情况下,Keras 能够从参数 x 推断输入形状。的 fit .

    关于python - 为什么当作为参数传递给 `input_shape` 层时, `Dense` 不包括批处理维度?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64681232/

    25 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com