Warning: file_get_contents(/data/phpspider/zhask/data//catemap/1/list/4.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python 有没有办法从CIFAR-10培训数据集中提取所需的类?_Python_List_Tensorflow_Arraylist_Conv Neural Network - Fatal编程技术网

Python 有没有办法从CIFAR-10培训数据集中提取所需的类?

Python 有没有办法从CIFAR-10培训数据集中提取所需的类?,python,list,tensorflow,arraylist,conv-neural-network,Python,List,Tensorflow,Arraylist,Conv Neural Network,我想做的看起来很简单,但就是不起作用。我想对每一类图像(矩阵)执行特定的操作,所以我首先必须从加扰的批次中提取它们中的每一个 from tensorflow.keras import datasets import numpy as np (train_images, train_labels), (test_images, test_labels)= datasets.cifar10.load_data() print(len(train_images)) print(len(train_i

我想做的看起来很简单,但就是不起作用。我想对每一类图像(矩阵)执行特定的操作,所以我首先必须从加扰的批次中提取它们中的每一个

from tensorflow.keras import datasets
import numpy as np

(train_images, train_labels), (test_images, test_labels)= datasets.cifar10.load_data()
print(len(train_images))
print(len(train_images))
train_images[train_labels==6]
这就是错误。当然,这是因为图像矩阵的形状(50000,32,32,3)。尽管图像和标签的长度相同,但python无法以某种方式使用矩阵作为1项进行过滤。欢迎帮助

50000
50000


---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-170-029cc3d4f0a9> in <module>
      5 
      6 
----> 7 train_images[train_labels==6]

IndexError: boolean index did not match indexed array along dimension 1; dimension is 32 but corresponding boolean dimension is 1
50000
50000
---------------------------------------------------------------------------
索引器回溯(最后一次最近调用)
在里面
5.
6.
---->7列车图像[列车标签==6]
索引器错误:布尔索引与维度1上的索引数组不匹配;维度为32,但相应的布尔维度为1

这里的问题是,train_标签具有形状(50000,1),因此当您对其进行索引时,numpy尝试将其用作二维。这里有一个简单的解决方案

from tensorflow.keras import datasets
import numpy as np

(train_images, train_labels), (test_images, test_labels)= datasets.cifar10.load_data()
print('Images Shape: {}'.format(train_images.shape))
print('Labels Shape: {}'.format(train_labels.shape))
idx = (train_labels == 6).reshape(train_images.shape[0])
print('Index Shape: {}'.format(idx.shape))
filtered_images = train_images[idx]
print('Filtered Images Shape: {}'.format(filtered_images.shape))