gpt4 book ai didi

Skorch:帮助构建多个输出的分类器

转载 作者:行者123 更新时间:2023-12-05 03:00:05 41 4
gpt4 key购买 nike

我正在尝试通过翻译一个简单的 pytorch 模型来学习 skorch,该模型预测一组 MNIST 多位数字图片中包含的 2 位数字。这些图片包含 2 个重叠的数字,它们是输出标签 (y)。我收到以下错误:

ValueError: Stratified CV requires explicitely passing a suitable y

我遵循“MNIST with SciKit-Learn and skorch”笔记本并通过创建自定义 get_loss 函数应用了“Multiple return values from forward”中概述的多个输出修复。

数据维度为:

  • X: (40000, 1, 4, 28)
  • y: (40000, 2)

代码:

class Flatten(nn.Module):
"""A custom layer that views an input as 1D."""

def forward(self, input):
return input.view(input.size(0), -1)


class CNN(nn.Module):

def __init__(self):
super(CNN, self).__init__()

self.conv1 = nn.Conv2d(1, 32, 3)
self.pool1 = nn.MaxPool2d((2, 2))
self.conv2 = nn.Conv2d(32, 64, 3)
self.pool2 = nn.MaxPool2d((2, 2))
self.flatten = Flatten()
self.fc1 = nn.Linear(2880, 64)
self.drop1 = nn.Dropout(p=0.5)
self.fc2 = nn.Linear(64, 10)
self.fc3 = nn.Linear(64, 10)

def forward(self, x):
x = F.relu(self.conv1(x))
x = self.pool1(x)
x = F.relu(self.conv2(x))
x = self.pool2(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.drop1(x)
out_first_digit = self.fc2(x)
out_second_digit = self.fc3(x)

return out_first_digit, out_second_digit


torch.manual_seed(0)

class CNN_net(NeuralNetClassifier):
def get_loss(self, y_pred, y_true, *args, **kwargs):

loss1 = F.cross_entropy(y_pred[0], y_true[:,0])
loss2 = F.cross_entropy(y_pred[1], y_true[:,1])

return 0.5 * (loss1 + loss2)

net = CNN_net(
CNN,
max_epochs=5,
lr=0.1,
device=device,
)

net.fit(X_train, y_train);
  1. 需要修改y的格式吗?
  2. 我是否需要构建额外的自定义函数(预测)?
  3. 还有其他建议吗?

最佳答案

skorch 的 NeuralNetClassifier 默认应用分层交叉验证拆分,为您提供训练期间的验证准确性等指标。当然,这使得您的数据可以以这种方式拆分是必要的。由于每个图像都有两个标签,因此没有简单的方法来进行分层拆分(尽管有 are ways )。

我想到了两个解决方案:

  1. 完全禁用训练拆分(通过 train_split=None)并在训练期间失去验证
  2. 通过传递 train_split=skorch.dataset.CVSplit(5, stratified=False)
  3. 将训练拆分更改为非分层

因为我猜你在训练期间需要验证指标,所以你的最终代码应该如下所示:

net = CNN_net(
CNN,
max_epochs=5,
lr=0.1,
device=device,
train_split=skorch.dataset.CVSplit(5, stratified=False),
)

net.fit(X_train, y_train);

关于Skorch:帮助构建多个输出的分类器,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57222733/

41 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com