
machine-learning - Classification with a deep network on the CIFAR-10 dataset


I am trying to build a classifier on the CIFAR-10 dataset using deep learning techniques, with a single hidden layer of 1024 nodes. Each image is 32*32*3 (R-G-B). Since my machine has limited processing power, I load data from only 3 of the 5 training batch files in the dataset.

from __future__ import print_function
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import os
import sys
import tarfile
import random
from IPython.display import display, Image
from scipy import ndimage
from sklearn.linear_model import LogisticRegression
from six.moves.urllib.request import urlretrieve
from six.moves import cPickle as pickle
from sklearn.preprocessing import MultiLabelBinarizer

folder='/home/cifar-10-batches-py/'

training_data=np.ndarray((30000,3072),dtype=np.float32)
training_labels=np.ndarray(30000,dtype=np.int32)

testing_data=np.ndarray((10000,3072),dtype=np.float32)
testing_labels=np.ndarray(10000,dtype=np.int32)

no_of_files=3

begin=0
end=10000

for i in range(no_of_files):
    # read the first three CIFAR-10 training batches (10000 images each)
    with open(folder+"data_batch_"+str(i+1),'rb') as f:
        s=pickle.load(f,encoding='bytes')
        training_data[begin:end]=s[b'data']
        training_labels[begin:end]=s[b'labels']
        begin=begin+10000
        end=end+10000

test_path='/home/cifar-10-batches-py/test_batch'
with open(test_path,'rb') as d:
    s9=pickle.load(d,encoding='bytes')
    tdata=s9[b'data']
    testing_data=tdata
    tlabels=s9[b'labels']
    testing_labels=tlabels
test_data=np.ndarray((5000,3072),dtype=np.float32)
test_labels=np.ndarray(5000,dtype=np.int32)
valid_data=np.ndarray((5000,3072),dtype=np.float32)
valid_labels=np.ndarray(5000,dtype=np.int32)

valid_data[:,:]=testing_data[:5000, :]
valid_labels[:]=testing_labels[:5000]
test_data[:,:]=testing_data[5000:, :]
test_labels[:]=testing_labels[5000:]
onehot_training_labels=np.eye(10)[training_labels.astype(int)]
onehot_test_labels=np.eye(10)[test_labels.astype(int)]
onehot_valid_labels=np.eye(10)[valid_labels.astype(int)]
image_size=32
num_labels=10
train_subset = 10000

def accuracy(predictions, labels):
    return (100.0 * np.sum(np.argmax(predictions, 1) == np.argmax(labels, 1))
            / predictions.shape[0])

batch_size = 128
relu_count = 1024  # number of hidden nodes

graph = tf.Graph()
with graph.as_default():
    # input placeholders for one minibatch
    tf_train_dataset = tf.placeholder(tf.float32,
                                      shape=(batch_size, image_size * image_size*3))
    tf_train_labels = tf.placeholder(tf.float32, shape=(batch_size, num_labels))
    tf_valid_dataset = tf.constant(valid_data)
    tf_test_dataset = tf.constant(test_data)
    beta_regul = tf.placeholder(tf.float32)  # L2 regularization strength

    # one hidden ReLU layer followed by a linear output layer
    weights1 = tf.Variable(
        tf.truncated_normal([image_size * image_size*3, relu_count]))
    biases1 = tf.Variable(tf.zeros([relu_count]))
    weights2 = tf.Variable(
        tf.truncated_normal([relu_count, num_labels]))
    biases2 = tf.Variable(tf.zeros([num_labels]))

    preds = tf.matmul(tf.nn.relu(tf.matmul(tf_train_dataset, weights1) + biases1), weights2) + biases2

    # cross-entropy loss plus an L2 penalty on both weight matrices
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=preds, labels=tf_train_labels)) + \
        beta_regul * (tf.nn.l2_loss(weights1) + tf.nn.l2_loss(weights2))

    optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(loss)
    train_prediction = tf.nn.softmax(preds)
    lay1_valid = tf.nn.relu(tf.matmul(tf_valid_dataset, weights1) + biases1)
    valid_prediction = tf.nn.softmax(tf.matmul(lay1_valid, weights2) + biases2)
    lay1_test = tf.nn.relu(tf.matmul(tf_test_dataset, weights1) + biases1)
    test_prediction = tf.nn.softmax(tf.matmul(lay1_test, weights2) + biases2)
num_steps = 5000

with tf.Session(graph=graph) as session:
    tf.initialize_all_variables().run()
    print("Initialized")
    for step in range(num_steps):
        # cycle through the training data in minibatches
        offset = (step * batch_size) % (onehot_training_labels.shape[0] - batch_size)
        batch_data = training_data[offset:(offset + batch_size), :]
        batch_labels = onehot_training_labels[offset:(offset + batch_size), :]
        feed_dict = {tf_train_dataset : batch_data, tf_train_labels : batch_labels, beta_regul : 1e-3}
        _, l, predictions = session.run(
            [optimizer, loss, train_prediction], feed_dict=feed_dict)
        if step % 500 == 0:
            print("Minibatch loss at step %d: %f" % (step, l))
            print("Minibatch accuracy: %.1f%%" % accuracy(predictions, batch_labels))
            print("Validation accuracy: %.1f%%" % accuracy(
                valid_prediction.eval(), onehot_valid_labels))
    print("Test accuracy: %.1f%%" % accuracy(test_prediction.eval(), onehot_test_labels))

The output of this code is:

WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/python/util/tf_should_use.py:170: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
Instructions for updating:
Use `tf.global_variables_initializer` instead.
Initialized
Minibatch loss at step 0: 117783.914062
Minibatch accuracy: 14.8%
Validation accuracy: 10.2%
Minibatch loss at step 500: 3632989892247552.000000
Minibatch accuracy: 12.5%
Validation accuracy: 10.1%
Minibatch loss at step 1000: 2203224941527040.000000
Minibatch accuracy: 6.2%
Validation accuracy: 9.9%
Minibatch loss at step 1500: 1336172110413824.000000
Minibatch accuracy: 10.9%
Validation accuracy: 9.8%
Minibatch loss at step 2000: 810328996708352.000000
Minibatch accuracy: 8.6%
Validation accuracy: 10.1%
Minibatch loss at step 2500: 491423044468736.000000
Minibatch accuracy: 9.4%
Validation accuracy: 10.1%
Minibatch loss at step 3000: 298025566076928.000000
Minibatch accuracy: 12.5%
Validation accuracy: 9.8%
Minibatch loss at step 3500: 180741635833856.000000
Minibatch accuracy: 10.9%
Validation accuracy: 9.8%
Minibatch loss at step 4000: 109611013111808.000000
Minibatch accuracy: 15.6%
Validation accuracy: 10.1%
Minibatch loss at step 4500: 66473376612352.000000
Minibatch accuracy: 3.9%
Validation accuracy: 9.9%
Test accuracy: 10.2%

Where am I going wrong? The accuracy I am getting is very low.

Best Answer

  1. As far as I can see, you are building a simple 2-layer feedforward network with TensorFlow. That is fine, but you will not get very high accuracy with it. If you do want to push it further, you need to tune the hyperparameters carefully: the learning rate, the regularization strength, the decay rate, and the number of neurons in the hidden layer (see the sketch after this list).

  2. You are not using all of the data, which is bound to reduce the quality of the predictions. It can still work, but you should check the class distribution in the training, validation, and test sets; some classes may end up with too few examples in one of the splits. At a minimum, you should stratify your selection (a stratified-split sketch also follows the list).

  3. Are you sure you have a solid grounding in deep learning? Working through the cs231n course might be a good idea.
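
As a concrete illustration of point 1, here is a minimal sketch of the changes that most often tame this particular setup: scale the raw pixel values, shrink the initial weights, and lower the learning rate. The numbers below are assumptions to start tuning from, not known-good values. With unit-variance initial weights on 3072 unscaled (0-255) inputs, the initial logits are enormous, which is consistent with the exploding minibatch losses in your output.

# Minimal sketch (assumed hyperparameters - tune these):
# 1) scale pixels from 0..255 down to 0..1 before building the graph,
#    so tf.constant(valid_data) and tf.constant(test_data) pick up the scaled arrays
training_data = training_data / 255.0
valid_data = valid_data / 255.0
test_data = test_data / 255.0

# 2) He-style initialization keeps the initial logits small
weights1 = tf.Variable(tf.truncated_normal(
    [image_size * image_size*3, relu_count],
    stddev=np.sqrt(2.0 / (image_size * image_size*3))))
weights2 = tf.Variable(tf.truncated_normal(
    [relu_count, num_labels], stddev=np.sqrt(2.0 / relu_count)))

# 3) a learning rate of 0.5 is far too aggressive for this network
optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

For point 2, a stratified split can be done with scikit-learn, which you already import from. This is a sketch assuming training_data and training_labels are filled as in your code; the stratify argument of train_test_split keeps the class proportions equal in both parts:

from sklearn.model_selection import train_test_split

# hold out 20% of the loaded data, with each class equally represented
train_part, holdout_part, train_lab, holdout_lab = train_test_split(
    training_data, training_labels,
    test_size=0.2, stratify=training_labels, random_state=42)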

Regarding machine-learning - Classification with a deep network on the CIFAR-10 dataset, a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/45422562/
