Tensorflow 如何解决TPU推理的数据获取瓶颈?

Tensorflow 如何解决TPU推理的数据获取瓶颈?,tensorflow,google-compute-engine,tpu,google-cloud-tpu,Tensorflow,Google Compute Engine,Tpu,Google Cloud Tpu,这就是我的推理设置的样子 autotune = tf.data.experimental.AUTOTUNE with strategy.scope(): model = LoadModel() raw_dataset = tf.data.TFRecordDataset(tfRecordAddress) train_dataset = raw_dataset.map(_parse_example, num_parallel_calls=autotune) trai

这就是我的推理设置的样子

autotune = tf.data.experimental.AUTOTUNE

with strategy.scope():
    model = LoadModel()
    raw_dataset = tf.data.TFRecordDataset(tfRecordAddress)
    train_dataset = raw_dataset.map(_parse_example, num_parallel_calls=autotune)
    train_dataset = train_dataset.padded_batch(batch_size, padding_values=(1, 1, b'-'), padded_shapes=(512, 512, 1))
    # train_dataset = train_dataset.repeat()
    train_dataset = train_dataset.prefetch(autotune)
    train_dataset = strategy.experimental_distribute_dataset(train_dataset)

def per_core_inference_fn(inputIds,attnIds ):
    return model.inference((inputIds, attnIds))

@tf.function
def inference_fn(inputIds, attnIds):
    return strategy.run(per_core_inference_fn, args=(inputIds,attnIds))

results = []
for x in train_dataset:
    t0 = time.time()
    results.append(inference_fn(x[0], x[1]))
    t1 = time.time()
    print('time is :', t1-t0)
由于批量很大,推断速度非常快,大约为.0003秒。但是,获取下一批数据需要很长时间,
对于train_dataset:
中的x,大约需要60-80秒

据我所知,我的推断是正确的,但不知何故,TPU的CPU在批量检索方面遇到了巨大的瓶颈


我在训练中没有看到这个瓶颈。因此,它看起来像是
model.fit
正在做一些我没有做的事情。

我有一种感觉,这一瓶颈的出现是由于x-in-train-u数据集的
。批量加载之间的60-80秒对我来说意味着预取没有按预期工作。在自定义训练循环(CTL)代码中,我通常会看到整个循环被包装在一个
tf.function
中,例如In

你能用同样的方法修改你的代码吗?您还可以尝试捕获TPU配置文件(),而不是使用
time.time()
进行基准测试