Tensorflow 如何从只输出数组列表的生成器开发tf.data对象?
我试图开发一个tf.data对象,该对象生成一个数组列表,但是我得到了一个数据不匹配错误。这是我的尝试Tensorflow 如何从只输出数组列表的生成器开发tf.data对象?,tensorflow,tensorflow-datasets,Tensorflow,Tensorflow Datasets,我试图开发一个tf.data对象,该对象生成一个数组列表,但是我得到了一个数据不匹配错误。这是我的尝试 def labelGen(): yield tf.constant([1, 0], dtype=tf.int64), tf.constant([1, 0], dtype=tf.int64), tf.constant([0, 1], dtype=tf.int64), tf.constant([0, 1], dtype=tf.int64) Labeldataset = tf.data.D
def labelGen():
yield tf.constant([1, 0], dtype=tf.int64), tf.constant([1, 0], dtype=tf.int64), tf.constant([0, 1], dtype=tf.int64), tf.constant([0, 1], dtype=tf.int64)
Labeldataset = tf.data.Dataset.from_generator(
labelGen, (tf.int64, tf.int64, tf.int64, tf.int64, tf.int64), ([], [], [], [], []) )
list(Labeldataset.take(1))
这就是我得到的错误
InvalidArgumentError:TypeError:生成器
生成的元素与预期结构不匹配。预期的结构是(tf.int64,tf.int64,tf.int64,tf.int64,tf.int64),但生成的元素是(,)。
回溯(最近一次呼叫最后一次):
首先,.from_生成器代码中的项目数不匹配。 第二,应该在不使用()的情况下调用生成器。 下面是在TF2.1中测试的工作代码
def labelGen():
yield tf.constant([1, 0], dtype=tf.int64), tf.constant([1, 0], dtype=tf.int64), tf.constant([0, 1], dtype=tf.int64), tf.constant([0, 1], dtype=tf.int64)
Labeldataset = tf.data.Dataset.from_generator(
labelGen, # without ()
(tf.int64, tf.int64, tf.int64, tf.int64), # should match number of items
(tf.TensorShape([2]), tf.TensorShape([2]), tf.TensorShape([2]), tf.TensorShape([2]))) # use tf.TensorShape
list(Labeldataset.take(1))
结果:
[(<tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 1])>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 1])>)]
[(,)]