gpt4 book ai didi

python - 在 chainer 中使用数组作为 MNIST 数据的标签

转载 作者:太空宇宙 更新时间:2023-11-03 14:26:40 27 4
gpt4 key购买 nike

Python模块chainer有一个 introduction它使用神经网络来识别 MNIST database 中的手写数字。 。

假设特定的手写数字D.png被标记为3。我习惯了标签以数组形式出现,如下所示:

label = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]

但是,chainer 使用整数进行标签:

label = 3

数组标签对我来说更直观,因为输出预测也是一个数组。在不处理图像的神经网络中,我希望能够灵活地将标签指定为特定数组。

我直接从链接器介绍中包含了下面的代码。如果您解析 traintest 数据集,请注意所有标签都是整数而不是 float 。

如何使用数组作为标签而不是整数来运行训练/测试数据?

import numpy as np
import chainer
from chainer import cuda, Function, gradient_check, report, training, utils, Variable
from chainer import datasets, iterators, optimizers, serializers
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions

class MLP(Chain):
def __init__(self, n_units, n_out):
super(MLP, self).__init__()
with self.init_scope():
# the size of the inputs to each layer will be inferred
self.l1 = L.Linear(None, n_units) # n_in -> n_units
self.l2 = L.Linear(None, n_units) # n_units -> n_units
self.l3 = L.Linear(None, n_out) # n_units -> n_out

def __call__(self, x):
h1 = F.relu(self.l1(x))
h2 = F.relu(self.l2(h1))
y = self.l3(h2)
return y

train, test = datasets.get_mnist()

train_iter = iterators.SerialIterator(train, batch_size=100, shuffle=True)
test_iter = iterators.SerialIterator(test, batch_size=100, repeat=False, shuffle=False)

model = L.Classifier(MLP(100, 10)) # the input size, 784, is inferred
optimizer = optimizers.SGD()
optimizer.setup(model)

updater = training.StandardUpdater(train_iter, optimizer)
trainer = training.Trainer(updater, (20, 'epoch'), out='result')

trainer.extend(extensions.Evaluator(test_iter, model))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(['epoch', 'main/accuracy', 'validation/main/accuracy']))
trainer.extend(extensions.ProgressBar())
trainer.run()

最佳答案

分类器接受包含图像或其他数据的元组作为数组(float32)和标签作为int。这是 chainer 的约定及其工作原理。如果打印标签,您将看到您正在获得一个 dtype int 的数组。图像/非图像数据和标签都将位于数组中,但数据类型分别为 float 和 int。

所以回答你的问题:你的标签本身采用数组格式,dtype int(标签应该如此)。

如果您希望标签为 0 和 1,而不是 1 到 10,请使用 One Hot Encoding( https://blog.cambridgespark.com/robust-one-hot-encoding-in-python-3e29bfcec77e )。

关于python - 在 chainer 中使用数组作为 MNIST 数据的标签,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47595883/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com