使用@tffunction的Tensorflow2警告

使用@tffunction的Tensorflow2警告,tensorflow,Tensorflow,此示例代码来自Tensorflow 2 writer = tf.summary.create_file_writer("/tmp/mylogs/tf_function") @tf.function def my_func(step): with writer.as_default(): # other model code would go here tf.summary.scalar("my_metric", 0.5, step=step) for step in ra

此示例代码来自Tensorflow 2

writer = tf.summary.create_file_writer("/tmp/mylogs/tf_function")

@tf.function
def my_func(step):
  with writer.as_default():
    # other model code would go here
    tf.summary.scalar("my_metric", 0.5, step=step)

for step in range(100):
  my_func(step)
  writer.flush()
但它正在发出警告

警告:tensorflow:最近5次调用中有5次调用触发了tf.function Retracting。追踪是昂贵的 而过多的跟踪可能是由于传递了python 对象而不是张量。此外,tf.function还具有 experimental_relax_shapes=用于松弛参数形状的True选项 这样可以避免不必要的回溯。请参阅 还有更多 细节


有更好的方法吗?

tf。函数
有一些“特性”。我强烈建议阅读这篇文章:

在这种情况下,问题在于每次使用不同的输入签名进行调用时,函数都会被“回溯”(即构建一个新的图形)。对于张量,输入签名指的是形状和数据类型,但对于Python数字,每个新值都被解释为“不同”。在这种情况下,由于您使用一个每次都会更改的
step
变量调用该函数,因此该函数也会每次都被回溯。对于“真实”代码(例如,在函数内部调用模型),这将非常缓慢

您可以通过简单地将
步骤
转换为张量来修复此问题,在这种情况下,不同的值将不算作新的输入签名:

for step in range(100):
    step = tf.convert_to_tensor(step, dtype=tf.int64)
    my_func(step)
    writer.flush()
或者使用
tf.range
直接获取张量:

for step in tf.range(100):
    step = tf.cast(step, tf.int64)
    my_func(step)
    writer.flush()

这应该不会产生警告(而且速度要快得多)。

这个问题是否会导致代码执行速度变慢,而不使用
@tf.function
?否,因为这样就没有要构建/回溯的图形。
@tf.function
只采用张量而不是定标器类型。转换成张量很好。