gpt4 book ai didi

machine-learning - 调用向后时 nn.CDivTable 抛出错误是否有正当理由?

转载 作者:行者123 更新时间:2023-11-30 09:51:52 25 4
gpt4 key购买 nike

我最近开始使用 Torch 框架和 Lua 脚本语言来研究神经网络。我已经掌握了线性网络的基础知识,所以我尝试了一些更复杂但足够简单的东西:

这个想法是,我有 3 个输入,我必须选择前两个,将它们相除,然后将结果转发到线性模块。所以,我制作了这个小脚本:

require "nn";
require "optim";

local N = 3;

local input = torch.Tensor{
{1, 2, 3},
{9, 20, 20},
{9, 300, 1},
};

local output = torch.Tensor(N);
for i=1, N do
output[i] = 1;
end

local ratioPerceptron = nn.Sequential();
ratioPerceptron:add(nn.Narrow(1, 1, 2));
ratioPerceptron:add(nn.CDivTable());
ratioPerceptron:add(nn.Reshape(N, 1));
ratioPerceptron:add(nn.Linear(1, 1));
ratioPerceptron:add(nn.Sigmoid());

local criterion = nn.BCECriterion();
local params, gradParams = ratioPerceptron:getParameters();
local optimState = {learningRate = 0.01};

local maxIteration = 100000;
for i=1, maxIteration do
local function f(params)
gradParams:zero();

local outputs = ratioPerceptron:forward(input);
local loss = criterion:forward(outputs, output);
local dloss_doutputs = criterion:backward(outputs, output);
ratioPerceptron:backward(input, dloss_doutputs);

return loss, gradParams;
end

optim.sgd(f, params, optimState);
end

当在训练期间调用向后并出现错误时,此操作会失败:

CDivTable.lua:21: both torch.LongStorage and (null) have no addition operator

但是,如果我从顺序模块中删除 CDivTable,并将 nn.Reshape 和 nn.Linear 更改为二维输入(因为我们删除了 CDivTable,它划分两维输入以产生一维输出),如下所示:

local ratioPerceptron = nn.Sequential();
ratioPerceptron:add(nn.Narrow(1, 1, 2));
ratioPerceptron:add(nn.Reshape(N, 2));
ratioPerceptron:add(nn.Linear(2, 1));
ratioPerceptron:add(nn.Sigmoid());

训练完成,没有错误...是否有其他方法来划分两个选定的输入并将结果转发到线性模块?

最佳答案

模块CDivTable将一个表作为输入,并将第一个表的元素除以第二个表的元素。在这里,您可以将单个输入作为网络的输入,而不是两个输入的表。我相信这就是为什么你会出现 null 错误的原因。 Torch 无法理解您的输入(由两个向量组成)应被视为两个向量的表。它只能看到大小为 2x3 的张量!因此,您必须告诉 Torch 根据输入创建一个表。因此,您可以使用模块 SplitTable(dim) 将输入沿维度 dim 拆分为表。

在窄模块后面插入此行 ratioPerceptron:add(nn.SplitTable(1)):

local ratioPerceptron = nn.Sequential();
ratioPerceptron:add(nn.Narrow(1, 1, 2));
ratioPerceptron:add(nn.SplitTable(1))
ratioPerceptron:add(nn.CDivTable());
ratioPerceptron:add(nn.Reshape(N, 1));
ratioPerceptron:add(nn.Linear(1, 1));
ratioPerceptron:add(nn.Sigmoid());

此外,当您遇到此类错误时,我建议您通过放置 print 语句来查看网络计算的内容:插入一行 print(ratioPerceptron:forward(input)) 在添加会产生错误的模块的行之前。

关于machine-learning - 调用向后时 nn.CDivTable 抛出错误是否有正当理由?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43680417/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com