Tensorflow 如何在tf2.0上同时进行多个回迁?

Tensorflow 如何在tf2.0上同时进行多个回迁?,tensorflow,tensorflow2.0,Tensorflow,Tensorflow2.0,在tf1中,我可以定义所需获取和使用的列表 sess.run(myList, feed_dict) 获取tf1通过图形同时计算的列表的所有元素。如何在tf2.0中执行此操作 tf1中的示例代码: import tensorflow as tf a = [None]*5 for i in range(5): a[i] = tf.Variable(tf.random.normal([3,3])) fetch_list = [None]*5 for i in range(5): fe

在tf1中,我可以定义所需获取和使用的列表

sess.run(myList, feed_dict)
获取tf1通过图形同时计算的列表的所有元素。如何在tf2.0中执行此操作

tf1中的示例代码:

import tensorflow as tf
a = [None]*5
for i in range(5):
    a[i] = tf.Variable(tf.random.normal([3,3]))
fetch_list = [None]*5
for i in range(5):
    fetch_list[i] = tf.add(tf.gather(a, i), tf.ones([3,3]))

sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(fetch_list)

我没有检查上面的代码是否运行,但我希望您能理解这一点。谢谢,因为默认情况下tf 2.x执行得很快,您只需执行以下操作:

然后像以前一样填充
fetch_list

根据你的实词例子的复杂性,你也可以考虑使用<代码> @ tf.Fult,它在推送数据之前建立一个执行图,以帮助类似于TF1的优化(这是一个巨大的过度简化,但你明白了)。p> <>你可以考虑对代码进行简单化/再修改,以促进这一点。可能最好只使用张量,而不是张量列表。很难确切地建议如何实现这一点,因为我不知道您为您的示例简化了什么

例如,如果我们认为你的代码> FutChyList是<代码>(5,3,3)< /代码>张量,而不是一个5代码>(3,3)张量的列表,那么我相信你会意识到你的简化示例代码(或多或少)归结为这样的事情:

@tf.function
def get_list(n):
  return tf.random.normal((n,3,3))

fetch_list = get_list(5)

上面的代码将串行执行它。您建议使用@tf.function使其同时运行。用tf.add(tf.gather,…)替换的完整操作的矢量化非常困难。在我的例子中,坚持列表要简单得多。但是,在您之前提供的解决方案中,我遇到了一些非常奇怪的问题。我的内存使用量不断增长,直到超出界限,就像我们在tf1中不断添加图形元素一样。
@tf.function
def get_list(n):
  return tf.random.normal((n,3,3))

fetch_list = get_list(5)