python - cifar10.load_data() takes a long time to download the data

Reposted. Author: 太空宇宙. Updated: 2023-11-04 02:31:39

Hi, I downloaded the cifar-10 dataset.

In my code, the dataset is loaded as follows:

import cv2
import numpy as np

from keras.datasets import cifar10
from keras import backend as K
from keras.utils import to_categorical

nb_train_samples = 3000  # 3000 training samples
nb_valid_samples = 100   # 100 validation samples
num_classes = 10

def load_cifar10_data(img_rows, img_cols):

    # Load cifar10 training and validation sets
    (X_train, Y_train), (X_valid, Y_valid) = cifar10.load_data()

    # Resize training images; cv2.resize expects the target size as (width, height)
    if K.image_data_format() == 'channels_first':
        # Images are (3, H, W): move channels last for OpenCV, then back again
        X_train = np.array([cv2.resize(img.transpose(1, 2, 0), (img_cols, img_rows)).transpose(2, 0, 1) for img in X_train[:nb_train_samples, :, :, :]])
        X_valid = np.array([cv2.resize(img.transpose(1, 2, 0), (img_cols, img_rows)).transpose(2, 0, 1) for img in X_valid[:nb_valid_samples, :, :, :]])
    else:
        # Images are already (H, W, 3)
        X_train = np.array([cv2.resize(img, (img_cols, img_rows)) for img in X_train[:nb_train_samples, :, :, :]])
        X_valid = np.array([cv2.resize(img, (img_cols, img_rows)) for img in X_valid[:nb_valid_samples, :, :, :]])

    # Transform targets to Keras-compatible one-hot format
    Y_train = to_categorical(Y_train[:nb_train_samples], num_classes)
    Y_valid = to_categorical(Y_valid[:nb_valid_samples], num_classes)

    return X_train, Y_train, X_valid, Y_valid

But downloading the dataset this way takes a very long time. Instead, I downloaded "cifar-10-python.tar.gz" manually. How can I load it into the variables (X_train, Y_train), (X_valid, Y_valid) without using cifar10.load_data()?
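One possible approach, sketched here under assumptions rather than taken from the thread: read the batches straight out of the downloaded archive with Python's standard tarfile and pickle modules, so nothing has to be extracted first. It assumes the official archive layout (a top-level cifar-10-batches-py folder containing data_batch_1 to data_batch_5 and test_batch); the function name load_cifar10_from_tar is hypothetical. The arrays are reshaped to the channels-last (N, 32, 32, 3) images and (N, 1) labels that cifar10.load_data() returns with the default Keras setting.

```python
import pickle
import tarfile

import numpy as np


def load_cifar10_from_tar(tar_path):
    """Hypothetical helper: read CIFAR-10 directly from cifar-10-python.tar.gz.

    Returns ((x_train, y_train), (x_test, y_test)) with images shaped
    (N, 32, 32, 3) uint8 and labels shaped (N, 1) uint8.
    """

    def read_batch(tar, member_name):
        # Each batch file is a pickled dict: b'data' is an (N, 3072) uint8 array,
        # b'labels' is a list of N ints in 0..9.
        with tar.extractfile(member_name) as f:
            batch = pickle.load(f, encoding='bytes')
        x = np.asarray(batch[b'data'], dtype=np.uint8).reshape(-1, 3, 32, 32)
        y = np.asarray(batch[b'labels'], dtype=np.uint8).reshape(-1, 1)
        # Stored layout is channels-first; move the channel axis last.
        return x.transpose(0, 2, 3, 1), y

    # Member names assume the official archive layout (assumption, see lead-in).
    with tarfile.open(tar_path, 'r:gz') as tar:
        train = [read_batch(tar, f'cifar-10-batches-py/data_batch_{i}') for i in range(1, 6)]
        x_test, y_test = read_batch(tar, 'cifar-10-batches-py/test_batch')

    x_train = np.concatenate([x for x, _ in train])
    y_train = np.concatenate([y for _, y in train])
    return (x_train, y_train), (x_test, y_test)
```

Usage, with the path pointing to wherever the archive was saved: `(X_train, Y_train), (X_valid, Y_valid) = load_cifar10_from_tar('cifar-10-python.tar.gz')`. This could replace the cifar10.load_data() call in the channels-last branch above; the channels-first branch would additionally need `transpose(0, 3, 1, 2)` on the images.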

Best answer

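Please excuse my English. I am also trying to load the cifar-10 dataset manually. In the code below, I extract cifar-10-python.tar.gz into a folder and load the file data_batch_1 from that folder into 4 arrays: x_train, y_train, x_test and y_test. 20% of data_batch_1 is used for validation as x_test and y_test, and the rest is used for training as x_train and y_train.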

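import os
import pickle
import numpy

# load data from the extracted folder (os.path.join builds a portable path)
with open(os.path.join('cifar-10-batches-py', 'data_batch_1'), 'rb') as f:
    dict1 = pickle.load(f, encoding='bytes')

# each row holds one 32x32 image: 1024 red, then 1024 green, then 1024 blue values
x = dict1[b'data']
x = x.reshape(len(x), 3, 32, 32).astype('float32')  # channels-first: (N, 3, 32, 32)

y = numpy.asarray(dict1[b'labels'])

# first 20% of the batch for validation, remaining 80% for training
x_test = x[0:int(0.2 * x.shape[0]), :, :, :]
y_test = y[0:int(0.2 * y.shape[0])]
x_train = x[int(0.2 * x.shape[0]):x.shape[0], :, :, :]
y_train = y[int(0.2 * y.shape[0]):y.shape[0]]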
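The answer above reads only data_batch_1 and keeps the channels-first layout. A minimal sketch that extends the same folder-based idea to all five training batches plus test_batch might look like this; the helper names load_batch and load_cifar10_dir are hypothetical, and the default folder name assumes the archive was extracted into the working directory. Images are transposed to channels-last (N, 32, 32, 3) uint8 so the shapes should match what cifar10.load_data() returns with the default Keras setting.

```python
import os
import pickle

import numpy as np


def load_batch(path):
    """Hypothetical helper: load one CIFAR-10 batch file as images and labels."""
    with open(path, 'rb') as f:
        batch = pickle.load(f, encoding='bytes')
    x = np.asarray(batch[b'data'], dtype=np.uint8).reshape(-1, 3, 32, 32)
    y = np.asarray(batch[b'labels'], dtype=np.uint8).reshape(-1, 1)
    # Stored layout is channels-first; move the channel axis last.
    return x.transpose(0, 2, 3, 1), y


def load_cifar10_dir(root='cifar-10-batches-py'):
    """Hypothetical helper: load data_batch_1..5 and test_batch from an extracted folder."""
    parts = [load_batch(os.path.join(root, f'data_batch_{i}')) for i in range(1, 6)]
    x_train = np.concatenate([x for x, _ in parts])
    y_train = np.concatenate([y for _, y in parts])
    x_test, y_test = load_batch(os.path.join(root, 'test_batch'))
    return (x_train, y_train), (x_test, y_test)
```

With the default channels-last setting, `(X_train, Y_train), (X_valid, Y_valid) = load_cifar10_dir('cifar-10-batches-py')` could stand in for the cifar10.load_data() call in the question's load_cifar10_data. Splitting off a validation set, as the answer does with its 20% slice, is then optional because test_batch already provides 10000 held-out images.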

A similar question about cifar10.load_data() taking a long time to download the data can be found on Stack Overflow: https://stackoverflow.com/questions/49045172/
