torch - PyTorch RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed

Reposted. Author: 行者123. Updated: 2023-12-01 20:05:31

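I'm trying to build a basic binary classifier in PyTorch that classifies whether my player is on the right or the left side in a game of Pong. The input is a 1x42x42 image and the label is my player's side (right = 1, left = 2). Code: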

import torch
import torch.nn as nn
from torch.autograd import Variable

class Net(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

net = Net(42 * 42, 100, 2)

# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer_net = torch.optim.Adam(net.parameters(), 0.001)
net.train()

# get_game_img() and get_player_side() are my own helpers (not shown)
while True:
    state = get_game_img()
    state = torch.from_numpy(state)

    # right = 1, left = 2
    current_side = get_player_side()
    target = torch.LongTensor(current_side)
    x = Variable(state.view(-1, 42 * 42))
    y = Variable(target)
    optimizer_net.zero_grad()
    y_pred = net(x)
    loss = criterion(y_pred, y)
    loss.backward()
    optimizer_net.step()

The error I get:

  File "train.py", line 109, in train
    loss = criterion(y_pred, y)
  File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 206, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/modules/loss.py", line 321, in forward
    self.weight, self.size_average)
  File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/functional.py", line 533, in cross_entropy
    return nll_loss(log_softmax(input), target, weight, size_average)
  File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/functional.py", line 501, in nll_loss
    return f(input, target)
  File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/_functions/thnn/auto.py", line 41, in forward
    output, *self.additional_args)
RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed. at /py/conda-bld/pytorch_1493676237139/work/torch/lib/THNN/generic/ClassNLLCriterion.c:57
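To see the check in isolation, here is a minimal sketch assuming a recent PyTorch version running on CPU. Newer releases no longer go through THNN and report an out-of-range target with different wording (typically an IndexError saying the target is out of bounds), so the sketch catches both IndexError and RuntimeError. The helper loss_or_error is a hypothetical name used only for illustration.

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
# One sample, two classes: the only valid class indices are 0 and 1.
logits = torch.randn(1, 2)


def loss_or_error(label):
    """Return the loss for a valid label, or the error text for an invalid one."""
    try:
        return criterion(logits, torch.tensor([label])).item()
    except (IndexError, RuntimeError) as err:
        return f"{type(err).__name__}: {err}"


print(loss_or_error(1))  # "right" in the question's encoding: in range, a loss is computed
print(loss_or_error(2))  # "left" = 2: out of range for 2 classes, so an error is reported
```

Label 1 happens to be accepted, but it still lands on the wrong class, so only the left-side frames crash while the right-side frames train silently against a shifted label.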

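Best answer

In most deep learning libraries, targets (labels) must start at 0.

That means that with n classes, your targets must lie in the range [0, n). Here n = 2, so the valid labels are 0 and 1; the label 2 used for "left" falls outside that range and triggers the assertion.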
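A minimal sketch of the fix, assuming a recent PyTorch (0.4 or later, where Variable is no longer needed). The names SIDE_TO_CLASS, make_target and train_step are hypothetical, and a random tensor stands in for the question's get_game_img(). The sketch shifts the question's 1/2 encoding to 0/1 and builds the target with torch.tensor([k]), which holds the value k. Note that the legacy call torch.LongTensor(k) with an integer argument does something different: it allocates an uninitialized tensor of size k.

```python
import torch
import torch.nn as nn

# Question's encoding: right = 1, left = 2. Shift to zero-based class indices.
SIDE_TO_CLASS = {1: 0, 2: 1}  # right -> 0, left -> 1


class Net(nn.Module):
    """Same two-layer MLP as in the question."""

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


def make_target(side):
    """Return a shape-(1,) int64 tensor holding the 0-based class index."""
    return torch.tensor([SIDE_TO_CLASS[side]], dtype=torch.long)


def train_step(net, optimizer, criterion, state, side):
    """Run one optimization step on a single 42x42 frame and return the loss."""
    x = state.reshape(1, 42 * 42).float()  # batch of one flattened image
    y = make_target(side)
    # Fail early with a clear message instead of deep inside the loss function.
    assert 0 <= y.item() < net.fc2.out_features, "target must be in [0, n_classes)"
    optimizer.zero_grad()
    loss = criterion(net(x), y)
    loss.backward()
    optimizer.step()
    return loss.item()


if __name__ == "__main__":
    torch.manual_seed(0)
    net = Net(42 * 42, 100, 2)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    fake_frame = torch.rand(1, 42, 42)  # stand-in for get_game_img()
    print(train_step(net, optimizer, criterion, fake_frame, side=2))
```

A lookup table rather than a bare "side - 1" keeps the mapping explicit, and it raises a KeyError if the game ever reports a side value outside the expected two.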

Regarding torch - PyTorch RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed, there is a similar question on Stack Overflow: https://stackoverflow.com/questions/45769206/