gpt4 book ai didi

c++ - 尺寸不匹配 CNN LibTorch/PyTorch

转载 作者:行者123 更新时间:2023-12-01 14:53:14 27 4
gpt4 key购买 nike

我在 LibTorch 中有一个 CNN 结构,但尺寸不正确。我的目标是输入 3 channel 64x64 图像并为 DGAN 输出逻辑回归浮点数。最后一层我设置为输入 channel 36,因为如果我移除该层,输出神经元的维度为 6x6,所以我猜这是全连接输入所需的维度。我想知道:

  • 您通常如何检查 LibTorch 或 Pytorch 中的尺寸(即检查最后一个模块所需的大小,检查每层有多少可训练参数...)
  • 这种情况下的错误是什么
  •    #include <torch/torch.h>
    #include "parameters.h"
    using namespace torch;

    class DCGANDiscriminatorImpl: public nn::Module {

    private:
    nn::Conv2d conv1, conv2, conv3, conv4;
    nn::BatchNorm2d batch_norm1, batch_norm2;
    nn::Linear fc1;

    public:
    DCGANDiscriminatorImpl()
    :conv1(nn::Conv2dOptions(3, 64, 4).stride(2).padding(1).bias(false)),

    conv2(nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)),

    batch_norm1(128),

    conv3(nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).bias(false)),

    batch_norm2(256),

    conv4(nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).bias(false)),

    fc1(6*6, 1)

    {
    register_module("conv1", conv1);
    register_module("conv2", conv2);
    register_module("conv3", conv3);
    register_module("conv4", conv4);
    register_module("batch_norm1", batch_norm1);
    register_module("batch_norm2", batch_norm2);
    register_module("fc1", fc1);

    }

    Tensor forward(torch::Tensor x)
    {
    x = leaky_relu(conv1(x), cte::NEGATIVE_SLOPE);
    x = leaky_relu(batch_norm1(conv2(x)), cte::NEGATIVE_SLOPE);
    x = leaky_relu(batch_norm2(conv3(x)), cte::NEGATIVE_SLOPE);
    x = sigmoid(fc1(x));
    return x;
    }

    };

    TORCH_MODULE(DCGANDiscriminator);

    我得到的错误是:
    libc++abi.dylib: terminating with uncaught exception of type std::runtime_error: size mismatch, m1: [131072 x 8], m2: [36 x 1] at ../aten/src/TH/generic/THTensorMath.cpp:136

    最佳答案

    我有几个问题,但最后这个架构奏效了。

    using namespace torch;

    class DCGANDiscriminatorImpl: public nn::Module {

    private:
    nn::Conv2d conv1, conv2, conv3, conv4;
    nn::BatchNorm2d batch_norm1, batch_norm2;
    nn::Linear fc1;

    public:
    DCGANDiscriminatorImpl()
    :conv1(nn::Conv2dOptions(3, 64, 4).stride(2).padding(1).bias(false)),

    conv2(nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)),

    batch_norm1(128),

    conv3(nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).bias(false)),

    batch_norm2(256),

    conv4(nn::Conv2dOptions(256, 64, 3).stride(1).padding(0).bias(false)),

    fc1(6*6*64, 1)

    {
    register_module("conv1", conv1);
    register_module("conv2", conv2);
    register_module("conv3", conv3);
    register_module("conv4", conv4);
    register_module("batch_norm1", batch_norm1);
    register_module("batch_norm2", batch_norm2);
    register_module("fc1", fc1);

    }

    Tensor forward(torch::Tensor x)
    {
    x = leaky_relu(conv1(x), cte::NEGATIVE_SLOPE);
    x = leaky_relu(batch_norm1(conv2(x)), cte::NEGATIVE_SLOPE);
    x = leaky_relu(batch_norm2(conv3(x)), cte::NEGATIVE_SLOPE);
    x = leaky_relu(conv4(x), cte::NEGATIVE_SLOPE);
    x = x.view({x.size(0), -1});
    x = sigmoid(fc1(x));
    return x;
    }

    };

    TORCH_MODULE(DCGANDiscriminator);

    关于c++ - 尺寸不匹配 CNN LibTorch/PyTorch,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60826846/

    27 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com