Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/tensorflow/5.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181

Warning: file_get_contents(/data/phpspider/zhask/data//catemap/9/git/25.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
如何在Tensorflow中显示Dataset对象中的类分布_Tensorflow_Dataset_Multilabel Classification - Fatal编程技术网

如何在Tensorflow中显示Dataset对象中的类分布

如何在Tensorflow中显示Dataset对象中的类分布,tensorflow,dataset,multilabel-classification,Tensorflow,Dataset,Multilabel Classification,我正在使用自己的图像进行多类分类任务 filenames = [] # a list of filenames labels = [] # a list of labels corresponding to the filenames full_ds = tf.data.Dataset.from_tensor_slices((filenames, labels)) 此完整数据集将被洗牌并拆分为训练数据集、有效数据集和测试数据集 full_ds_size = len(filenames) ful

我正在使用自己的图像进行多类分类任务

filenames = [] # a list of filenames
labels = [] # a list of labels corresponding to the filenames
full_ds = tf.data.Dataset.from_tensor_slices((filenames, labels))
此完整数据集将被洗牌并拆分为训练数据集、有效数据集和测试数据集

full_ds_size = len(filenames)
full_ds = full_ds.shuffle(buffer_size=full_ds_size*2, seed=128) # seed is used for reproducibility

train_ds_size = int(0.64 * full_ds_size)
valid_ds_size = int(0.16 * full_ds_size)

train_ds = full_ds.take(train_ds_size)
remaining = full_ds.skip(train_ds_size)  
valid_ds = remaining.take(valid_ds_size)
test_ds = remaining.skip(valid_ds_size)
现在我正在努力理解每个类是如何分布在训练、有效和测试中的。一个丑陋的解决方案是迭代数据集中的所有元素并计算每个类的出现次数。有没有更好的办法解决这个问题

我丑陋的解决方案:

def get_class_distribution(dataset):
    class_distribution = {}
    for element in dataset.as_numpy_iterator():
        label = element[1]

        if label in class_distribution.keys():
            class_distribution[label] += 1
        else:
            class_distribution[label] = 0

    # sort dict by key
    class_distribution = collections.OrderedDict(sorted(class_distribution.items())) 
    return class_distribution


train_ds_class_dist = get_class_distribution(train_ds)
valid_ds_class_dist = get_class_distribution(valid_ds)
test_ds_class_dist = get_class_distribution(test_ds)

print(train_ds_class_dist)
print(valid_ds_class_dist)
print(test_ds_class_dist)

下面的答案假设:

  • 有五节课
  • 标签是0到4之间的整数
它可以根据您的需要进行修改

定义计数器函数:

def count_class(计数、批次、num_class=5):
标签=批次['label']
对于范围内的i(num_类):
cc=tf.cast(标签==i,tf.int32)
计数[i]+=tf.reduce_和(cc)
返回计数
使用
reduce
操作:

initial_state=dict((i,0)表示范围(5)内的i)
计数=列车减少(初始状态=初始状态,
reduce_func=计数(类)
打印([(k,v.numpy())表示k,v在counts.items()中)