根据模型预测过滤Tensorflow数据集

根据模型预测过滤Tensorflow数据集,tensorflow,tensorflow2.0,tensorflow-datasets,Tensorflow,Tensorflow2.0,Tensorflow Datasets,我想通过只选择经过训练的模型正确预测的样本来过滤TensorFlow数据集。数据集由图像、标签对组成。我试过这个: predicted_labels = np.argmax(model.predict(dataset), axis=1) predicted_labels_dataset = tf.data.Dataset.from_tensor_slices(predicted_labels) zipped_dataset = tf.data.Dataset.zip((dataset, pred

我想通过只选择经过训练的模型正确预测的样本来过滤TensorFlow数据集。数据集由图像、标签对组成。我试过这个:

predicted_labels = np.argmax(model.predict(dataset), axis=1)
predicted_labels_dataset = tf.data.Dataset.from_tensor_slices(predicted_labels)
zipped_dataset = tf.data.Dataset.zip((dataset, predicted_labels_dataset))
correct_dataset = zipped_dataset.filter(lambda sample, predicted_label: tf.math.equal(x = predicted_label, y = sample[1]))

但它并没有像预期的那样起作用。我仍然在
正确的\u数据集中得到错误分类的样本。提前感谢您的帮助

我认为这不起作用的唯一原因是,您的
示例[1]
不代表标签(或格式不正确)。我运行了下面的玩具示例,效果很好。很高兴根据您的后续行动完善答案,直到我们确定确切的问题

import numpy as np
import tensorflow as tf


predicted_labels_dataset = tf.data.Dataset.from_tensor_slices([0,1,2,1,2,3])

inp_ds = tf.data.Dataset.from_tensor_slices(tf.random.normal(shape=[6,5]))
lbl_ds = tf.data.Dataset.from_tensor_slices([1,1,0,0,2,3])
dataset = tf.data.Dataset.zip((inp_ds, lbl_ds))
zipped_dataset = tf.data.Dataset.zip((dataset, predicted_labels_dataset))

correct_dataset = zipped_dataset.filter(lambda sample, predicted_label: tf.math.equal(x = predicted_label, y = sample[1]))

谢谢你的回复!我认为问题实际上在于keras模式。predict()。似乎在批处理上运行它会返回与在没有批处理的情况下运行它不同的标签。你有什么建议吗?