Python 无法将自定义模型保存为.pb格式(tensorflow 2.1.0)

Python 无法将自定义模型保存为.pb格式(tensorflow 2.1.0),python,tensorflow,keras,lstm,Python,Tensorflow,Keras,Lstm,我正在尝试将带有自定义层的Keras模型(.h5/json)转换为tensorflow模型(.pb),并使用以下自定义LSTM层 class AttentionVLSTMCell(AbstractRNNCell): def __init__(...): super(AttentionVLSTMCell, self).__init__(**kwargs) ... @property def state_size(self): return [self.times

我正在尝试将一个带有自定义层的 Keras 模型(.h5/json)转换为 TensorFlow 模型(.pb),该模型使用了以下自定义 LSTM 层:

class AttentionVLSTMCell(AbstractRNNCell):

  def __init__(...):

    super(AttentionVLSTMCell, self).__init__(**kwargs)

...

  @property
  def state_size(self):
    """Sizes of the two recurrent states carried between steps.

    The first entry is the flattened attention memory
    (timestep * att_input_dim values); the second is the vertex state
    (units values).
    """
    attention_state_size = self.timestep * self.att_input_dim
    vertex_state_size = self.units
    return [attention_state_size, vertex_state_size]

  def build(self, input_shape): # difinition of the weights

...

  def _time_distributed_dense(self, x, w, b=None, dropout=None,
                              input_dim=None, output_dim=None,
                              timesteps=None, training=None):
    """Apply the dense map ``x @ w (+ b)`` independently at every timestep.

    Args:
      x: 3-D tensor; axis 1 is treated as time and axis 2 as features
        (see the ``K.shape(x)[1]`` / ``K.shape(x)[2]`` defaults below).
      w: 2-D kernel; its second dimension is the output width.
      b: optional bias added after the matmul.
      dropout: optional rate; when in (0, 1), one dropout mask is drawn per
        sample and reused at every timestep, only in the training phase.
      input_dim, output_dim, timesteps: optional sizes, inferred from ``x``
        and ``w`` when omitted.
      training: forwarded to ``K.in_train_phase``.

    Returns:
      Tensor reshaped to (-1, timesteps, output_dim).
    """
    # No @tf.function here: this helper is only called from `call`, which
    # Keras already traces. A nested tf.function creates a separate traced
    # function per distinct set of Python arguments and complicates
    # SavedModel (.pb) export.
    import numbers

    if input_dim is None:
        input_dim = K.shape(x)[2]
    if timesteps is None:
        timesteps = K.shape(x)[1]
    if output_dim is None:
        # Prefer the static width: set_shape() below needs a Python int (or
        # None) and raises when handed a tensor such as K.shape(w)[1].
        output_dim = K.int_shape(w)[1]
        if output_dim is None:
            output_dim = K.shape(w)[1]
    static_output_dim = (
        output_dim if isinstance(output_dim, numbers.Integral) else None)

    if dropout is not None and 0. < dropout < 1.:
        # apply the same dropout pattern at every timestep
        ones = K.ones_like(K.reshape(x[:, 0, :], (-1, input_dim)))
        dropout_matrix = K.dropout(ones, dropout)
        expanded_dropout_matrix = K.repeat(dropout_matrix, timesteps)
        x = K.in_train_phase(x * expanded_dropout_matrix, x, training=training)

    # Fold time into the batch axis so a single 2-D matmul covers all steps.
    x = K.reshape(x, (-1, input_dim))
    x = K.dot(x, w)
    if b is not None:
        x = K.bias_add(x, b)
    # Unfold back to a 3-D (batch, timesteps, output_dim) tensor.
    if K.backend() == 'tensorflow':
        x = K.reshape(x, K.stack([-1, timesteps, output_dim]))
        x.set_shape([None, None, static_output_dim])
    else:
        x = K.reshape(x, (-1, timesteps, output_dim))
    return x

  def _compute_update_vertex(self, x, V_tm1, c):
    """Compute the gate output and the new vertex state using split kernels.

    Args:
      x: 4-tuple ``(x_i, x_f, x_c, x_o)`` of input projections
        (W * input, plus bias when enabled) for the i/f/c/o gates.
      V_tm1: 7-tuple of previous-vertex tensors, unpacked as
        ``(i, f, c, o, o2, u, v)``.
      c: 4-tuple ``(c_i, c_f, c_c, c_o)`` of attention-context tensors.

    Returns:
      ``(h, V)`` where
      ``h = dense_activation((o * activation(V_tm1_o2)) @ dense_kernel)`` and
      ``V = h * U + (1 - h) * V_tm1_v``.
    """
    x_i, x_f, x_c, x_o = x
    V_tm1_i, V_tm1_f, V_tm1_c, V_tm1_o, V_tm1_o2, V_tm1_u, V_tm1_v = V_tm1
    c_i, c_f, c_c, c_o = c
    units = self.units

    def pre_activation(x_k, V_k, c_k, start, stop):
      # x_k + V_k @ recurrent_kernel[:, start:stop]
      #     + c_k @ context_kernel[:, start:stop]
      return (x_k
              + K.dot(V_k, self.recurrent_kernel[:, start:stop])
              + K.dot(c_k, self.context_kernel[:, start:stop]))

    i = self.recurrent_activation(
        pre_activation(x_i, V_tm1_i, c_i, 0, units))
    f = self.recurrent_activation(
        pre_activation(x_f, V_tm1_f, c_f, units, units * 2))
    # Candidate values; named c_hat so the `c` argument is not shadowed.
    c_hat = self.activation(
        pre_activation(x_c, V_tm1_c, c_c, units * 2, units * 3))

    # U = update vertex, formed like an LSTM cell-state update.
    U = f * V_tm1_u + i * c_hat

    # The output gate uses every kernel column from units * 3 onward.
    o = self.recurrent_activation(
        pre_activation(x_o, V_tm1_o, c_o, units * 3, None))

    # NOTE(review): a standard LSTM applies `activation` to the freshly
    # updated state (here U); this uses the previous vertex V_tm1_o2.
    # Confirm this is intended.
    h_temp = o * self.activation(V_tm1_o2)
    # NOTE(review): the width of h is fixed by dense_kernel (created in
    # build, not visible here); earlier comments disagreed ([units] vs [1]).
    h = self.dense_activation(K.dot(h_temp, self.dense_kernel))

    # Gate between the update vertex U and the previous vertex.
    V = h * U + (1 - h) * V_tm1_v
    return h, V

  def call(self, inputs, states, training=None):
    """Run one step of the attention-augmented vertex LSTM.

    Args:
      inputs: current-step input; repeated over ``self.timestep`` and
        multiplied by ``self.kernel``. Presumably (batch, input_dim); the
        printed shapes in the question show (None, 256). TODO confirm.
      states: ``[att, V_tm1]``. ``att`` is the flattened attention memory,
        reshaped here to (-1, timestep, att_input_dim); ``V_tm1`` is the
        previous vertex state.
      training: not used by this method; kept for the Keras cell API.

    Returns:
      ``([h, at], [att, V])``: the gate output and the attention weights as
      outputs, plus the unchanged attention memory and the new vertex state
      as next states.

    Raises:
      NotImplementedError: if ``self.implementation`` is not 1 (only the
        split-kernel computation exists).
    """
    if self.implementation != 1:
      # Only the split-kernel path is implemented. Any other value used to
      # fall through and fail with UnboundLocalError on `h` / `V`.
      raise NotImplementedError(
          'AttentionVLSTMCell only supports implementation=1, got %r'
          % (self.implementation,))

    att = states[0]  # flattened attention memory
    V_tm1 = states[1]  # previous vertex state
    # Keep per-step tensors as locals rather than `self.` attributes, so
    # graph tensors created while tracing are not stored on the layer.
    x_seq = K.reshape(att, (-1, self.timestep, self.att_input_dim))

    # Additive attention (Bahdanau et al., 2015). The dense projection of
    # the memory does not depend on previous steps.
    uxpb = self._time_distributed_dense(x_seq,
                                        self.attention_kernel_U,
                                        b=self.attention_kernel_b,
                                        input_dim=self.att_input_dim,
                                        timesteps=self.timestep,
                                        output_dim=self.units)

    # Repeat the current input along the memory's time axis and score
    # every memory slot against it.
    tt = K.repeat(inputs, self.timestep)
    wxtt = K.dot(tt, self.attention_kernel_W)
    et = K.dot(activations.tanh(wxtt + uxpb),
               K.expand_dims(self.attention_kernel_V))
    # Disabled alternative: dot-product attention (Luong et al., 2015) /
    # scaled dot-product attention (Vaswani et al., 2017):
    #   x_seq /= np.sqrt(self.att_input_dim)
    #   et = K.batch_dot(K.expand_dims(inputs), x_seq, axes=[1, 2])
    #   et = K.reshape(et, (-1, self.timestep, 1))

    # Softmax over the time axis (axis 1). Same result as exp(et) / sum,
    # but numerically stable: large scores no longer overflow exp().
    at = K.softmax(et, axis=1)
    context = K.squeeze(K.batch_dot(at, x_seq, axes=1), axis=1)

    # Split the input kernel / bias into the four gate slices.
    k_i, k_f, k_c, k_o = array_ops.split(
        self.kernel, num_or_size_splits=4, axis=1)
    x_i = K.dot(inputs, k_i)
    x_f = K.dot(inputs, k_f)
    x_c = K.dot(inputs, k_c)
    x_o = K.dot(inputs, k_o)
    if self.use_bias:
      b_i, b_f, b_c, b_o = array_ops.split(
          self.bias, num_or_size_splits=4, axis=0)
      x_i = K.bias_add(x_i, b_i)
      x_f = K.bias_add(x_f, b_f)
      x_c = K.bias_add(x_c, b_c)
      x_o = K.bias_add(x_o, b_o)

    # Every recurrent slot sees the same previous vertex, and every
    # context slot the same attention context.
    x = (x_i, x_f, x_c, x_o)
    vertex_parts = (V_tm1,) * 7
    context_parts = (context,) * 4

    h, V = self._compute_update_vertex(x, vertex_parts, context_parts)
    return [h, at], [att, V]

  def get_config(self):

...

  def get_initial_state(self, inputs=True, batch_size=True, dtype=None):

...
问题出在 V_tm1_i 上(参见下面 LSTM 输入门的计算):

我知道“MatMul”要求输入张量至少是二维的(ndims >= 2),但这里的张量实际上带有批量大小维度。因此,我期望两个操作数的形状分别为
[None, 256] 和 [256, 256]

自定义层的输入和隐藏状态(打印结果)如下所示:

AttentionVLSTM
inputs
Tensor("model/Decoder_Attention_VLSTM/strided_slice_1:0", shape=(None, 256), dtype=float32)
states
(<tf.Tensor 'model/reshape/Reshape:0' shape=(None, 27136) dtype=float32>, <tf.Tensor 'model/Decoder_Activation_2/Relu:0' shape=(None, 256) dtype=float32>)
AttentionVLSTM
inputs
Tensor("TensorArrayV2Read/TensorListGetItem:0", shape=(None, 256), dtype=float32)
states
(<tf.Tensor 'Placeholder_3:0' shape=(None, 27136) dtype=float32>, <tf.Tensor 'Placeholder_4:0' shape=(None, 256) dtype=float32>)
AttentionVLSTM
inputs
Tensor("inputs:0", shape=(None, 256), dtype=float32)
states
(<tf.Tensor 'states:0' shape=(27136,) dtype=float32>, <tf.Tensor 'states_1:0' shape=(256,) dtype=float32>)

我已经解决了这个问题。错误的原因是该自定义类末尾对初始状态的定义(get_initial_state)。上面最后一组输出中,states 的形状为 (27136,) 和 (256,),缺少批量维度,因此“MatMul”收到了一维张量。这个方法是不必要的,因为“AbstractRNNCell”中已经定义了初始状态,删除它即可。

    # Input-gate computation quoted from _compute_update_vertex; the question
    # reports the "MatMul" rank error on the K.dot(V_tm1_i, ...) term.
    i = self.recurrent_activation(
        x_i 
        + K.dot(V_tm1_i, self.recurrent_kernel[:, :self.units])
        + K.dot(c_i, self.context_kernel[:, :self.units]))
AttentionVLSTM
inputs
Tensor("model/Decoder_Attention_VLSTM/strided_slice_1:0", shape=(None, 256), dtype=float32)
states
(<tf.Tensor 'model/reshape/Reshape:0' shape=(None, 27136) dtype=float32>, <tf.Tensor 'model/Decoder_Activation_2/Relu:0' shape=(None, 256) dtype=float32>)
AttentionVLSTM
inputs
Tensor("TensorArrayV2Read/TensorListGetItem:0", shape=(None, 256), dtype=float32)
states
(<tf.Tensor 'Placeholder_3:0' shape=(None, 27136) dtype=float32>, <tf.Tensor 'Placeholder_4:0' shape=(None, 256) dtype=float32>)
AttentionVLSTM
inputs
Tensor("inputs:0", shape=(None, 256), dtype=float32)
states
(<tf.Tensor 'states:0' shape=(27136,) dtype=float32>, <tf.Tensor 'states_1:0' shape=(256,) dtype=float32>)
AttentionVLSTM
inputs
Tensor("model/Decoder_Attention_VLSTM/strided_slice_1:0", shape=(32, 256), dtype=float32)
states
(<tf.Tensor 'model/reshape/Reshape:0' shape=(32, 27136) dtype=float32>, <tf.Tensor 'model/Decoder_Activation_2/Relu:0' shape=(32, 256) dtype=float32>)
AttentionVLSTM
inputs
Tensor("TensorArrayV2Read/TensorListGetItem:0", shape=(32, 256), dtype=float32)
states
(<tf.Tensor 'Placeholder_3:0' shape=(32, 27136) dtype=float32>, <tf.Tensor 'Placeholder_4:0' shape=(32, 256) dtype=float32>)