Tensorflow 如何使用tf.dataset进行预测

Tensorflow 如何使用tf.dataset进行预测,tensorflow,keras,Tensorflow,Keras,我对tf.dataset和tf.keras.predict()有问题。我不知道为什么predict()的输出数组的长度大于所用数据的原始长度。这是一张草图: 在我使用数组之前。如果我对lenght x数组应用predict(),我会得到lenght x的输出。。。这是我预期的行为 我有一些长度(10000)的测试数据的csv。现在我用 LABEL_COLUMN = 'label' LABELS = [0, 1] def get_dataset(file_path, **kwargs): d

我对tf.dataset和tf.keras.predict()有问题。我不知道为什么predict()的输出数组的长度大于所用数据的原始长度。这是一张草图: 在我使用数组之前。如果我对lenght x数组应用predict(),我会得到lenght x的输出。。。这是我预期的行为

我有一些长度(10000)的测试数据的csv。现在我用

LABEL_COLUMN = 'label'
LABELS = [0, 1]

def get_dataset(file_path, **kwargs):
  dataset = tf.data.experimental.make_csv_dataset(
      file_path,
      batch_size=1, # Artificially small to make examples easier to show.
      label_name=LABEL_COLUMN,
      na_value="?",
      num_epochs=1,
      ignore_errors=True,
      **kwargs)
  return dataset
将其转换为tf.dataset

val='data/test.csv'
val_data= get_dataset(val)
现在使用

scores=bert_model.predict(val_data)
提供一个数组输出,它比原始csv文件(10000)的输出大很多

我真的走了。我还问自己,keras如何知道tf.dataset的哪些“键”用于Predictrion

1的结构。数据集的元素看起来像“val[0]”:

({'input_id':,'token_type_id':,'attention_mask':},)
  • 为什么我的标签列没有名为“label”的键?前3个键都有它们的名称,模型使用这3列进行训练
  • 我使用上面的结构和标签列作为预测的输入


    有什么想法吗?这是因为从csv生成数据集的功能吗?

    您是否检查了数据集的大小,而不是csv文件,以查看它是否真的包含10000个样本?我该怎么做?由于我使用从csv文件生成,因此它返回一个平面数据集,其中基数为-2如果启用了渴望模式,则
    len(dataset)
    ?我使用tf2。是否需要启用“急切模式额外”功能?如果尚未禁用,则不需要。默认情况下,它将被启用。
    ({'input_ids': <tf.Tensor: shape=(15,), dtype=int32, numpy=
    array([    3,  2019,   479,  1169,  4013, 26918,   259,     4, 14576,
            3984,   889,   648,  1610, 26918,     4])>, 'token_type_ids': <tf.Tensor: shape=(15,), dtype=int32, numpy=array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1])>, 'attention_mask': <tf.Tensor: shape=(15,), dtype=int32, numpy=array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])>}, <tf.Tensor: shape=(), dtype=int64, numpy=0>)