gpt4 book ai didi

python - PyTorch flatten 不保持批量大小

转载 作者:行者123 更新时间:2023-12-04 11:18:39 26 4
gpt4 key购买 nike

在 Keras 中,使用 Flatten()层保留批量大小。例如,如果 Flatten 的输入形状是 (32, 100, 100) , 在 Keras Flatten 的输出是 (32, 10000) ,但在 PyTorch 中是 320000 .为什么会这样?

最佳答案

正如 OP 在他们的回答中已经指出的那样,张量操作不会默认考虑批量维度。您可以使用 torch.flatten() Tensor.flatten() start_dim=1在批量维度之后开始展平操作。

或者,从 PyTorch 1.2.0 开始,您可以定义 nn.Flatten() 模型中的图层默认为 start_dim=1 .

关于python - PyTorch flatten 不保持批量大小,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60115633/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com