Python 调用函数中的可变批量大小

Python 调用函数中的可变批量大小,python,tensorflow,keras,tensorflow2.0,Python,Tensorflow,Keras,Tensorflow2.0,我试图用TensorFlow 2实现一个注意力网络。因此,对于每幅图像,我只想看几眼,即图像的一小部分。为此,我从tensorflow.keras.models.Model中实现了一个子类,下面是其中的一个片段 class RecurrentAttentionModel(models.Model): # ... def call(self, inputs): l = tf.random.uniform((40,2,), minval=0, maxval=1) for _ i

我试图用TensorFlow 2实现一个注意力网络。因此,对于每幅图像,我只想看几眼,即图像的一小部分。为此,我从tensorflow.keras.models.Model中实现了一个子类,下面是其中的一个片段

class RecurrentAttentionModel(models.Model):
# ...

def call(self, inputs):

    l = tf.random.uniform((40,2,), minval=0, maxval=1)

    for _ in range(0, self.glimpses):
        glimpse = tf.image.extract_glimpse(inputs, size=(self.retina_size, self.retina_size), offsets=l, centered=False, normalized=True)

        # some other code...
        # update l to take a glimpse somewhere else


    return result           
现在,上面的代码工作和训练都很完美,但我的问题是,我有硬编码的40,我在数据集中定义的批量大小。我无法在call方法中读取/获取批处理大小,因为变量“inputs”的形式为
Tensor(“input\u 1\u 77:0”,shape=(None,250500,1),dtype=float32)
,其中批处理大小的
None
似乎是预期行为。 当我用以下代码初始化l时(没有批处理大小)

它抛出了这个错误

ValueError: Shape must be rank 2 but is rank 1 for 'recurrent_attention_model_86/ExtractGlimpse' (op: 'ExtractGlimpse') with input shapes: [?,250,500,1], [2], [2]

我完全理解,但我不知道如何根据批次大小实现初始值。

您可以使用
tf.shape
动态提取批次大小维度

l = tf.random.normal(tf.stack([tf.shape(inputs)[0], 2]), minval=0, maxval=1))
l = tf.random.normal(tf.stack([tf.shape(inputs)[0], 2]), minval=0, maxval=1))