Python 保持示例索引与tf.keras.predict和tf.data.Dataset的对应关系

Python 保持示例索引与tf.keras.predict和tf.data.Dataset的对应关系,python,tensorflow,keras,Python,Tensorflow,Keras,我正在TensorFlow2中使用tf.keras API。我有100000个左右的图像保存为TFRecords(每个记录128个图像)。每个记录都有一个输入图像、目标图像和帧索引。我找不到一个干净的方法来保持帧索引和预测 以下是一个示例,但我使用NumPy数组构建数据集,而不是从TFRecords读取数据: import tensorflow as tf from tensorflow import keras import numpy as np # build dummy tf.data

我正在TensorFlow2中使用tf.keras API。我有100000个左右的图像保存为TFRecords(每个记录128个图像)。每个记录都有一个输入图像、目标图像和帧索引。我找不到一个干净的方法来保持帧索引和预测

以下是一个示例,但我使用NumPy数组构建数据集,而不是从TFRecords读取数据:

import tensorflow as tf
from tensorflow import keras
import numpy as np

# build dummy tf.data.Dataset
x = np.random.random(10000).astype(np.float32)
y = x + np.random.random(10000).astype(np.float32) * 0.1
idx = np.arange(10000, dtype=np.uint16)
np.random.shuffle(idx)  # frames are random in my TFRecord files
ds = tf.data.Dataset.from_tensor_slices((x, y, idx))
# pretend ds returned from TFRecord
ds = ds.map(lambda f0, f1, f2: (f0, f1))  # strip off idx
ds = ds.batch(32)

# build and train model
x = keras.Input(shape=(1,))
y_hat = keras.layers.Dense(1)(x)  # i.e. linear regression
model = keras.Model(x, y_hat)
model.compile('sgd', 'mse')
history = model.fit(ds, epochs=5)

# predict 1 batch
model.predict(ds, steps=1)

除了再次读取数据集以提取索引(容易出错)之外,是否有一种干净的方法来保持预测与图像索引的对应关系?在TF1.x中,这很简单。但我想利用TF2中干净的Keras compile()、fit()、predict()API。

好的,我想得太多了,实际上很简单。在进行预测时,只需将索引添加到数据集,并在迭代批处理时拉出索引:

rt tensorflow as tf
from tensorflow import keras
import numpy as np

def build_dataset(mode):
    np.random.seed(1)
    x = np.random.random(10000).astype(np.float32)
    y = x + np.random.random(10000).astype(np.float32) * 0.1
    idx = np.arange(10000, dtype=np.uint16)
    if mode == 'train':
        ds = tf.data.Dataset.from_tensor_slices((x, y))
        ds = ds.shuffle(128)
    else:
        ds = tf.data.Dataset.from_tensor_slices((x, idx))
    ds = ds.batch(32)

    return ds

# build and train simple linear regression model
x_tf = keras.Input(shape=(1,))
yhat_tf = keras.layers.Dense(1)(x_tf)
model = keras.Model(x_tf, yhat_tf)
model.compile(optimizer='sgd', loss='mse')
ds = build_dataset('train')
history = model.fit(ds, epochs=5)

# predict 1 batch
ds = build_dataset('predict')
for batch in ds:
    x_tf, indices_tf = batch 
    yhat_np = model.predict(x_tf)
    break