Tensorflow 如何从只输出数组列表的生成器开发tf.data对象?

Tensorflow 如何从只输出数组列表的生成器开发tf.data对象?,tensorflow,tensorflow-datasets,Tensorflow,Tensorflow Datasets,我试图开发一个tf.data对象,该对象生成一个数组列表,但是我得到了一个数据不匹配错误。这是我的尝试 def labelGen(): yield tf.constant([1, 0], dtype=tf.int64), tf.constant([1, 0], dtype=tf.int64), tf.constant([0, 1], dtype=tf.int64), tf.constant([0, 1], dtype=tf.int64) Labeldataset = tf.data.D

我试图开发一个tf.data对象,该对象生成一个数组列表,但是我得到了一个数据不匹配错误。这是我的尝试

def labelGen():
    yield tf.constant([1, 0], dtype=tf.int64), tf.constant([1, 0], dtype=tf.int64), tf.constant([0, 1], dtype=tf.int64), tf.constant([0, 1], dtype=tf.int64)

Labeldataset = tf.data.Dataset.from_generator(
     labelGen, (tf.int64, tf.int64, tf.int64, tf.int64, tf.int64), ([], [], [], [], []) )

list(Labeldataset.take(1))
这就是我得到的错误

InvalidArgumentError:TypeError:
生成器
生成的元素与预期结构不匹配。预期的结构是(tf.int64,tf.int64,tf.int64,tf.int64,tf.int64),但生成的元素是(,)。 回溯(最近一次呼叫最后一次):


首先,.from_生成器代码中的项目数不匹配。 第二,应该在不使用()的情况下调用生成器。 下面是在TF2.1中测试的工作代码

def labelGen():
    yield tf.constant([1, 0], dtype=tf.int64), tf.constant([1, 0], dtype=tf.int64), tf.constant([0, 1], dtype=tf.int64), tf.constant([0, 1], dtype=tf.int64)

Labeldataset = tf.data.Dataset.from_generator(
    labelGen, # without ()
    (tf.int64, tf.int64, tf.int64, tf.int64), # should match number of items
    (tf.TensorShape([2]), tf.TensorShape([2]), tf.TensorShape([2]), tf.TensorShape([2]))) # use tf.TensorShape

list(Labeldataset.take(1))
结果:

[(<tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 1])>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 1])>)]
[(,)]