Python 如何在Tensorflow中按张量形状过滤数据集

Python 如何在Tensorflow中按张量形状过滤数据集,python,tensorflow,tensorflow2.0,data-processing,Python,Tensorflow,Tensorflow2.0,Data Processing,我已从tfds.load加载了一个数据集,并希望丢弃某些干扰正确训练/对我没有用处的图像(例如,图像太小) 似乎在任何地方都没有关于这个特定问题的信息,所以我选择了最适合的数据集,即.filter(谓词)。不幸的是,谓词的输入具有不确定的形状(None,None,3),并且如预期的那样引发了一个错误,“int”不能与“NoneType”进行比较 用tensorflow解决这个问题是可能的还是我不应该浪费时间 伪码 ds_train = tfds.load('name') ds_train = d

我已从tfds.load加载了一个数据集,并希望丢弃某些干扰正确训练/对我没有用处的图像(例如,图像太小)

似乎在任何地方都没有关于这个特定问题的信息,所以我选择了最适合的数据集,即.filter(谓词)。不幸的是,谓词的输入具有不确定的形状(None,None,3),并且如预期的那样引发了一个错误,“int”不能与“NoneType”进行比较

用tensorflow解决这个问题是可能的还是我不应该浪费时间

伪码

ds_train = tfds.load('name')
ds_train = ds_train.map(lambda ds: ds['image'])
ds_train = ds_train.filter(lambda image: image.shape[0] >= 256)

使用
tf.data.Dataset
编写代码时,应该使用
tf.shape(tensor)
而不是
tensor.shape
,因为
tf.data.Dataset
在图形模式下工作

引用以下文件:

tf.shape和Tensor.shape在渴望模式下应该相同。在tf.function或compat.v1上下文中,直到执行时才可能知道所有维度。因此,在为图形模式定义自定义层和模型时,更喜欢动态tf.shape(x)而不是静态x.shape


对于它所造成的巨大痛苦,我们没有想到会有这么简单的解决方案
ds_train = ds_train.filter(lambda image: tf.shape(image)[0] >= 256)