Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/tensorflow/5.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
如何为TensorFlow数据集获取每个类的样本_Tensorflow_Tensorflow2.0_Tensorflow Datasets - Fatal编程技术网

如何为TensorFlow数据集获取每个类的样本

如何为TensorFlow数据集获取每个类的样本,tensorflow,tensorflow2.0,tensorflow-datasets,Tensorflow,Tensorflow2.0,Tensorflow Datasets,我正在使用来自TensorFlow数据集的数据集。 是否有一种简单的方法可以访问数据集中每个类的样本数?我在keras api中搜索,没有找到任何现成的函数 最后,我想用Y轴上的样本数绘制一个条形图,int表示X轴上的类id。目标是显示数据在类中的分布是多么均匀。使用np。从iter可以从iterable对象创建一维数组 import tensorflow_datasets as tfds import numpy as np import seaborn as sns dataset = t

我正在使用来自TensorFlow数据集的数据集。 是否有一种简单的方法可以访问数据集中每个类的样本数?我在keras api中搜索,没有找到任何现成的函数


最后,我想用Y轴上的样本数绘制一个条形图,int表示X轴上的类id。目标是显示数据在类中的分布是多么均匀。

使用
np。从iter
可以从iterable对象创建一维数组

import tensorflow_datasets as tfds
import numpy as np
import seaborn as sns

dataset = tfds.load('cifar10', split='train', as_supervised=True)

labels, counts = np.unique(np.fromiter(dataset.map(lambda x, y: y), np.int32), 
                       return_counts=True)

plt.ylabel('Counts')
plt.xlabel('Labels')
sns.barplot(x = labels, y = counts)