Python Tensorflow中自定义静态张量的未知批量维保持

Python Tensorflow中自定义静态张量的未知批量维保持,python,tensorflow,keras,Python,Tensorflow,Keras,注意:我正在使用tensorflow 2.3.0、python 3.8.2和numpy 1.18.5(但不确定这是否重要) 我正在编写一个自定义层,它在内部存储形状(a,b)的不可训练张量N,其中a,b是已知值(该张量是在init期间创建的)。调用输入张量时,它会展平输入张量,展平存储的张量,并将两者连接在一起。不幸的是,我似乎不知道如何在连接过程中保留未知的批处理维度。下面是最简单的代码: import tensorflow as tf from tensorflow.keras.layers

注意:我正在使用tensorflow 2.3.0、python 3.8.2和numpy 1.18.5(但不确定这是否重要)

我正在编写一个自定义层,它在内部存储形状(a,b)的不可训练张量N,其中a,b是已知值(该张量是在init期间创建的)。调用输入张量时,它会展平输入张量,展平存储的张量,并将两者连接在一起。不幸的是,我似乎不知道如何在连接过程中保留未知的批处理维度。下面是最简单的代码:

import tensorflow as tf
from tensorflow.keras.layers import Layer, Flatten

class CustomLayer(Layer):
   def __init__(self, N):                     # N is a tensor of shape (a, b), where a, b > 1
      super(CustomLayer, self).__init__()
      self.N = self.add_weight(name="N", shape=N.shape, trainable=False, initializer=lambda *args, **kwargs: N)

      # correct me if I'm wrong in using this initializer approach, but for some reason, when I
      # just do self.N = N, this variable would disappear when I saved and loaded the model

   def build(self, input_shape):
      pass                                    # my reasoning is that all the necessary stuff is handled in init

   def call(self, input_tensor):
      input_flattened = Flatten()(input_tensor)
      N_flattened = Flatten()(self.N)
      return tf.concat((input_flattened, N_flattened), axis=-1)
我注意到的第一个问题是
flatte()(self.N)
将返回一个与原始
self.N
具有相同形状(a,b)的张量,因此,返回值的形状将为(a,num\u input\u tensor\u values+b)。我的理由是,第一个维度a被视为批量大小。我修改了
调用
函数:

   def call(self, input_tensor):
      input_flattened = Flatten()(input_tensor)
      N = tf.expand_dims(self.N, axis=0)       # N would now be shape (1, a, b)
      N_flattened = Flatten()(N)
      return tf.concat((input_flattened, N_flattened), axis=-1)

这将返回一个带有形状(1,num_input_vals+a*b)的张量,这很好,但现在批次维度永久为1,我在开始使用此层训练模型时意识到了这一点,它只适用于批次大小为1的情况。这在模型摘要中也很明显——如果我将这个层放在输入之后,然后添加其他层,那么输出张量的第一个维度类似于
None,1,1,1,1…
。是否有一种方法可以存储此内部张量并在
调用中使用它,同时保留可变的批量大小?(例如,当批大小为4时,相同的展平N的副本将连接到4个展平输入张量中的每个张量的末端。)

您必须拥有与输入中的样本数量相同的展平
N
向量,因为您要连接到每个样本。将其视为将行配对并连接它们。如果只有一个
N
向量,则只能连接一对向量。 要解决这个问题,您应该使用
tf.tile()
重复
N
,重复次数与批次中的样本数量相同

例如:

   def call(self, input_tensor):
      input_flattened = Flatten()(input_tensor) # input_flattened shape: (None, ..)
      N = tf.expand_dims(self.N, axis=0)        # N shape: (1, a, b)
      N_flattened = Flatten()(N)                # N_flattened shape: (1, a*b)
      N_tiled = tf.tile(N_flattened, [tf.shape(input_tensor)[0], 1]) # repeat along the first dim as many times, as there are samples and leave the second dim alone
      return tf.concat((input_flattened, N_tiled), axis=-1)

我不确定我是否理解,所以:你想在输入批次中的每个展平样本之后连接展平N张量吗?@AndreaAngeli yeah。抱歉搞混了-我会把问题修改得更清楚。你是个天才。我以前看过
tf.tile()
,但我认为这与此无关。我刚刚尝试了你的代码,它工作得很好,保留了未知的批处理维度。非常感谢你!