Python TensorFlow Federated:如何为具有多个输入的模型编写输入规范

Python TensorFlow Federated:如何为具有多个输入的模型编写输入规范,python,tensorflow,tensorflow-federated,Python,Tensorflow,Tensorflow Federated,我试图使用tensorflow提供的联邦学习库制作一个图像字幕模型,但我被这个错误卡住了 密集层的输入0与层不兼容::预期最小值ndim=2,发现ndim=1。 这是我的输入规格: input_spec=collections.OrderedDict(x=(tf.TensorSpec(shape=(2048,), dtype=tf.float32), tf.TensorSpec(shape=(34,), dtype=tf.int32)), y=tf.TensorSpec(shape=(None)

我试图使用tensorflow提供的联邦学习库制作一个图像字幕模型,但我被这个错误卡住了

密集层的输入0与层不兼容::预期最小值ndim=2,发现ndim=1。

这是我的输入规格:

input_spec=collections.OrderedDict(x=(tf.TensorSpec(shape=(2048,), dtype=tf.float32), tf.TensorSpec(shape=(34,), dtype=tf.int32)), y=tf.TensorSpec(shape=(None), dtype=tf.int32))

该模型将图像特征作为第一个输入,将词汇表列表作为第二个输入,但我无法在input_spec变量中表达这一点。我试着把它表达成一个列表,但还是不起作用。下一步我可以试试什么?

好问题!在我看来,这个错误来自TensorFlow本身——这表明您可能拥有正确的嵌套结构,但叶子可能已经脱落。从TFF的角度来看,您的输入规范似乎“应该可以工作”,所以它可能与您拥有的数据有点不匹配

我要尝试的第一件事是——如果您有一个示例
tf.data.Dataset
,它将被传递到您的客户机计算中,您可以简单地将
input\u spec
作为
element\u spec
属性直接从该数据集读取。这看起来像:

# ds = example dataset
input_spec = ds.element_spec
这是最简单的方法。如果您有类似于“numpy数组列表”的内容,那么仍然有一种方法可以从数据本身中提取这些信息——下面的代码片段应该可以帮助您:

# data = list of list of numpy arrays
input_spec = tf.nest.map_structure(lambda x: tf.TensorSpec(x.shape, x.dtype), data)
最后,如果您有一个
tf.Tensors
列表,TensorFlow提供了一个类似的功能:

# tensor_structure = list of lists of tensors
tf.nest.map_structure(tf.TensorSpec.from_tensor, tensor_structure)

简而言之,我不会手动指定
input\u spec
,而是让数据告诉您它的输入规范应该是什么。

我尝试了您的第一个建议,它成功了!非常感谢你!!