Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/python/364.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python cifar10.load_data()下载数据需要很长时间_Python_Keras - Fatal编程技术网

Python cifar10.load_data()下载数据需要很长时间

Python cifar10.load_data()下载数据需要很长时间,python,keras,Python,Keras,嗨,我下载了cifar-10数据集 在我的代码中,它加载数据集,如下所示 import cv2 import numpy as np from keras.datasets import cifar10 from keras import backend as K from keras.utils import np_utils nb_train_samples = 3000 # 3000 training samples nb_valid_samples = 100 # 100 valid

嗨,我下载了cifar-10数据集

在我的代码中,它加载数据集,如下所示

import cv2
import numpy as np

from keras.datasets import cifar10
from keras import backend as K
from keras.utils import np_utils

nb_train_samples = 3000 # 3000 training samples
nb_valid_samples = 100 # 100 validation samples
num_classes = 10

def load_cifar10_data(img_rows, img_cols):

    # Load cifar10 training and validation sets
    (X_train, Y_train), (X_valid, Y_valid) = cifar10.load_data()

    # Resize trainging images
    if K.image_dim_ordering() == 'th':
        X_train = np.array([cv2.resize(img.transpose(1,2,0), (img_rows,img_cols)).transpose(2,0,1) for img in X_train[:nb_train_samples,:,:,:]])
        X_valid = np.array([cv2.resize(img.transpose(1,2,0), (img_rows,img_cols)).transpose(2,0,1) for img in X_valid[:nb_valid_samples,:,:,:]])
    else:
        X_train = np.array([cv2.resize(img, (img_rows,img_cols)) for img in X_train[:nb_train_samples,:,:,:]])
        X_valid = np.array([cv2.resize(img, (img_rows,img_cols)) for img in X_valid[:nb_valid_samples,:,:,:]])

    # Transform targets to keras compatible format
    Y_train = np_utils.to_categorical(Y_train[:nb_train_samples], num_classes)
    Y_valid = np_utils.to_categorical(Y_valid[:nb_valid_samples], num_classes)

    return X_train, Y_train, X_valid, Y_valid

但下载数据集需要很长时间。相反,我手动下载了“cifar-10-python.tar.gz”。那么我如何才能将其加载到变量中,(X\u-train,Y\u-train),(X\u-valid,Y\u-valid)而不是使用cifar10.load\u data()?

请原谅我的英语。我正在尝试手动加载cifar-10数据集。在下面的代码中,我将
cifar-10-python.tar.gz解包到一个文件夹中,并将文件
data\u batch\u 1
从文件夹加载到4个数组中:
x\u train
y\u train
x\u test
y\u test
。20%的
数据\u批次\u 1
用于验证
x\u测试
y\u测试
,其余用于培训
x\u列车
y\u列车

import pickle
import numpy
# load data
with open('cifar-10-batches-py\\data_batch_1','rb') as f:
    dict1 = pickle.load(f,encoding='bytes')

x = dict1[b'data']
x = x.reshape(len(x), 3, 32, 32).astype('float32')

y = numpy.asarray(dict1[b'labels'])

x_test = x[0:int(0.2 * x.shape[0]), :, :, :]
y_test = y[0:int(0.2 * y.shape[0])]
x_train = x[int(0.2 * x.shape[0]):x.shape[0], :, :, :]
y_train = y[int(0.2 * y.shape[0]):y.shape[0]]

这里的代码从相应的批处理文件中读取训练和测试图像,如中所述,修改自,并进行了很好的解释

import pickle
import numpy as np

for i in range(1,6):
    path = 'data_batch_' + str(i)
    with open(path, mode='rb') as file:
        # note the encoding type is 'latin1'
        batch = pickle.load(file, encoding='latin1')
    if i == 1:  
        x_train = (batch['data'].reshape((len(batch['data']), 3, 32, 32)).transpose(0, 2, 3, 1)).astype('float32')
        y_train = batch['labels']
    else:
        x_train_temp = (batch['data'].reshape((len(batch['data']), 3, 32, 32)).transpose(0, 2, 3, 1)).astype('float32')
        y_train_temp = batch['labels']
        x_train = np.concatenate((x_train,x_train_temp),axis = 0)
        y_train = np.concatenate((y_train,y_train_temp),axis=0)

path = 'test_batch'
with open(path,'rb') as file:
    # note the encoding type is 'latin1'
    batch = pickle.load(file, encoding='latin1')
    x_test = (batch['data'].reshape((len(batch['data']), 3, 32, 32)).transpose(0, 2, 3, 1)).astype('float32')
    y_test = batch['labels']
我们可以将读取的数据可视化如下:

import matplotlib.pyplot as plt

x_train=x_train.astype(np.uint8)
y_train = np.expand_dims(y_train, axis = 1)

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

plt.figure(figsize=(10,10))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(np.squeeze(x_train[i]), cmap=plt.cm.binary)
    # The CIFAR labels happen to be arrays, 
    # which is why you need the extra index
    plt.xlabel(class_names[y_train[i][0]])
plt.show()

此外,如果下载时间是您唯一的问题,您仍然可以使用
load_data()

小修正,data_batch1-data_batch5包含训练图像,“test_batch”包含测试图像。你不必像以前那样把阅读词典分成测试和训练两部分。资料来源:官方数据集网页