Python 无法使Dataset.filter()在model/official/resnet/resnet\u run\u loop.py文件中工作

Python 无法使Dataset.filter()在model/official/resnet/resnet\u run\u loop.py文件中工作,python,tensorflow,resnet,Python,Tensorflow,Resnet,在官方的resnet模型中,当eval_仅设置为True时,我希望通过'label'的值过滤test.bin中的数据集。我尝试使用tf.data.Dataset.filter()函数只获取一类测试数据,但没有成功 dataset = dataset.filter(lambda inputs, label: tf.equal(label,15)) 我将此代码放在resnet\u run\u loop.process\u record\u dataset函数中,但它引发了一个错误 raise V

在官方的resnet模型中,当eval_仅设置为True时,我希望通过'label'的值过滤test.bin中的数据集。我尝试使用tf.data.Dataset.filter()函数只获取一类测试数据,但没有成功

dataset = dataset.filter(lambda inputs, label: tf.equal(label,15))
我将此代码放在resnet\u run\u loop.process\u record\u dataset函数中,但它引发了一个错误

 raise ValueError("`predicate` must return a scalar boolean tensor.")

我发现张量'label'的形状是(?,):'tensor(“arg1:0”,shape=(?,),dtype=int32,device=/device:CPU:0)”

我在不同的情况下遇到了相同的问题,正如评论中所建议的,这个问题是由筛选前的批处理引起的

您可以使用以下示例再现此情况:

import pprint
import tensorflow as tf

dataset = tf.data.Dataset.zip((
    tf.data.Dataset.range(0, 5),
    tf.data.Dataset.from_tensor_slices([0, 10, 15, 20, 15])
))
pprint.pprint(list(dataset.as_numpy_iterator()))
# [(0, 0), (1, 10), (2, 15), (3, 20), (4, 15)]

filtered = dataset.filter(lambda x, y: y == 15)
pprint.pprint(list(filtered.as_numpy_iterator()))
# [(2, 15), (4, 15)]

BATCH_SIZE = 2
batched = dataset.batch(BATCH_SIZE)
batched_filtered = batched.filter(lambda x, y: y == 15)
# ValueError: `predicate` return type must be convertible to a scalar boolean tensor. Was [...]
此问题的一个简单解决方案是对数据集进行筛选,然后再次进行批处理:

BATCH_SIZE = 2
batched = dataset.batch(BATCH_SIZE)
batched_filtered = batched.unbatch().filter(lambda x, y: y == 15).batch(BATCH_SIZE)
pprint.pprint(list(batched_filtered.as_numpy_iterator()))
# [(array([1, 2]), array([15, 15], dtype=int32)),
#  (array([4]), array([15], dtype=int32))]
如果您不知道或不想跟踪
BATCH\u SIZE
的值,可以根据需要调整计算批次大小

我最终将这两种解决方案结合在一起,如下所示:

def calculate_batch_size(dataset):
    return next(iter(dataset))[0].shape[0]

def filter_batch(dataset, pred_fn):
    batch_size = calculate_batch_size(dataset)
    return dataset.unbatch().filter(pred_fn).batch(batch_size)

BATCH_SIZE = 2
batched = dataset.batch(BATCH_SIZE)
batched_filtered = filter_batch(batched, lambda x, y: y == 15)
pprint.pprint(list(batched_filtered.as_numpy_iterator()))
# [(array([1, 2]), array([15, 15], dtype=int32)),
#  (array([4]), array([15], dtype=int32))]

输入是图像的32*32*3数据,“标签”是类别号。15代表“瓶子”。顺便说一下,我使用的是cifar100数据集。传递给filter的函数应该返回一个标量,如错误消息所示,但是
标签的形状未指定,因此
tf.equal()
的输出形状未指定。如果在调用
filter
之前对数据集进行批处理,则
label
的大小实际上将大于1。确保在批处理之前进行筛选,以便可以安全地将
标签的形状指定为
[]
(=标量)。此外,您使用的是什么版本的TF?我看到1.10.0中的
过滤器
出现了一个问题,在1.8.0中,合法标量工作正常,因此出现了这个错误。不过,这可能无关紧要。