Keras model.fit()生成;“非类型对象不可调用”;错误,有调试提示吗?

Keras model.fit()生成;“非类型对象不可调用”;错误,有调试提示吗?,keras,tensorflow2.0,Keras,Tensorflow2.0,这是我看到的错误,使用transformer chatbot教程 我将模型子类化类(位置编码、多头注意层)转换为函数,以便将模型保存为h5文件。 我使用以下代码将“位置编码、自定义训练、多头注意层”类更改为函数: 为了检查代码是否正确转换,而不仅仅是共享多头注意层 #MULTI HEADED ATTENTION LAYER def split_heads(inputs, batch_size, num_heads, d_model): depth = d_model // num_head

这是我看到的错误,使用transformer chatbot教程 我将模型子类化类(位置编码、多头注意层)转换为函数,以便将模型保存为h5文件。 我使用以下代码将“位置编码、自定义训练、多头注意层”类更改为函数: 为了检查代码是否正确转换,而不仅仅是共享多头注意层


#MULTI HEADED ATTENTION LAYER
def split_heads(inputs, batch_size, num_heads, d_model):
  depth = d_model // num_heads
  inputs = tf.reshape(inputs, shape=(batch_size, -1, num_heads, depth))
  return tf.transpose(inputs, perm=[0, 2, 1, 3])


def call(d_model, num_heads, inputs):
  query, key, value, mask = inputs['query'], inputs['key'], inputs[
        'value'], inputs['mask']
  batch_size = tf.shape(query)[0]
  depth = d_model // num_heads

    # linear layers
  query = tf.keras.layers.Dense(units=d_model)(query)
  key = tf.keras.layers.Dense(units=d_model)(key)
  value = tf.keras.layers.Dense(units=d_model)(value)

    # split heads
  query = split_heads(query, batch_size, num_heads, d_model)
  key = split_heads(key, batch_size, num_heads, d_model)
  value = split_heads(value, batch_size, num_heads, d_model)

    # scaled dot-product attention
  scaled_attention = scaled_dot_product_attention(query, key, value, mask)

  scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])

    # concatenation of heads
  concat_attention = tf.reshape(scaled_attention,
                                  (batch_size, -1, d_model))

  dense = tf.keras.layers.Dense(units=d_model)

    # final linear layer
  outputs = dense(concat_attention)

  return outputs

它相当于上面提到的transformer chatbot教程链接中提到的代码。 我最终需要移动android设备上的transformer聊天机器人机制