Python Keras: argmax has no gradient. How can I define a gradient for argmax?

ValueError: An operation has `None` for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.
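
Argmax is piecewise constant. Nudging an input slightly never changes which index wins, so its derivative is zero almost everywhere and undefined at ties. Below is a minimal NumPy sketch of that, plus one common workaround that is a substitute technique and not part of the question: a "soft-argmax" that replaces the hard index with the expected index under `softmax(beta * logits)`, which is smooth. The function names and the `beta` temperature are hypothetical choices for illustration. In Keras the same formula would be written with backend ops such as `K.softmax` and `K.sum` so that autodiff can differentiate it.

```python
import numpy as np

def softmax(x):
    # Subtract the max before exponentiating for numerical stability
    z = x - np.max(x, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)

def soft_argmax(logits, beta=10.0):
    # Expected index under softmax(beta * logits): a smooth stand-in for argmax.
    # Larger beta moves the result closer to the hard argmax.
    p = softmax(beta * logits)
    idx = np.arange(logits.shape[-1], dtype=float)
    return np.sum(p * idx, axis=-1)

def soft_argmax_grad(logits, beta=10.0):
    # Analytic gradient for a 1-D input: d/dx_j = beta * p_j * (j - E[idx])
    p = softmax(beta * logits)
    idx = np.arange(logits.shape[-1], dtype=float)
    return beta * p * (idx - np.sum(p * idx))

x = np.array([0.1, 2.0, 0.5, 1.9])
eps = 1e-4
# Nudging any single input by eps does not change the hard argmax,
# so its finite-difference "gradient" is zero in every direction
hard_fd = [(np.argmax(x + eps * np.eye(4)[j]) - np.argmax(x)) / eps for j in range(4)]
print(np.array(hard_fd))    # [0. 0. 0. 0.]
print(soft_argmax_grad(x))  # non-zero, so gradient-based training can use it
```

Note that soft-argmax returns a real number rather than an integer id. It therefore cannot be fed straight into an `Embedding` layer, and it only approximates the hard index when one logit clearly dominates.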

My model:


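from keras import backend as K
from keras.layers import Input, Dense, Embedding, LSTM, Dropout, Lambda
from keras.models import Model

# train_data, train_label and vocabulary are defined elsewhere
latent_dim = 512
encoder_inputs = Input(shape=(train_data.shape[1],))
encoder_dense = Dense(vocabulary, activation='softmax')
encoder_outputs = Embedding(vocabulary, latent_dim)(encoder_inputs)
encoder_outputs = LSTM(latent_dim, return_sequences=True)(encoder_outputs)
encoder_outputs = Dropout(0.5)(encoder_outputs)
encoder_outputs = encoder_dense(encoder_outputs)
# K.argmax returns integer indices and has no gradient, so the layers above
# cannot be trained through it: this is the op the ValueError points at
encoder_outputs = Lambda(K.argmax, arguments={'axis':-1})(encoder_outputs)
encoder_outputs = Lambda(K.cast, arguments={'dtype':'float32'})(encoder_outputs)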

encoder_dense1 = Dense(train_label.shape[1], activation='softmax')
decoder_embedding = Embedding(vocabulary, latent_dim)
decoder_lstm1 = LSTM(latent_dim, return_sequences=True)
decoder_lstm2 = LSTM(latent_dim, return_sequences=True)
decoder_dense2 = Dense(vocabulary, activation='softmax')

decoder_outputs = encoder_dense1(encoder_outputs)
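# Embedding expects integer ids; a float input is cast to int32, which is not differentiable either
decoder_outputs = decoder_embedding(decoder_outputs)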
decoder_outputs = decoder_lstm1(decoder_outputs)
decoder_outputs = decoder_lstm2(decoder_outputs)
decoder_outputs = Dropout(0.5)(decoder_outputs)
decoder_outputs = decoder_dense2(decoder_outputs)
model = Model(encoder_inputs, decoder_outputs)
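
Another widely used workaround for the `Lambda(K.argmax, ...)` step is the straight-through estimator, again a substitute technique rather than anything the question uses. The forward pass emits the hard one-hot argmax, and the backward pass pretends the layer was a softmax. The NumPy sketch below writes both passes out by hand. The function names are hypothetical, and it assumes the input is pre-softmax logits over the vocabulary.

```python
import numpy as np

def softmax(x):
    # Subtract the max before exponentiating for numerical stability
    z = x - np.max(x, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)

def st_argmax_forward(logits):
    # Forward pass: an exact one-hot vector at the argmax position
    hard = np.zeros_like(logits)
    np.put_along_axis(hard, np.argmax(logits, axis=-1)[..., None], 1.0, axis=-1)
    return hard

def st_argmax_backward(logits, upstream):
    # Backward pass: ignore the (zero) gradient of argmax and instead return
    # the vector-Jacobian product of softmax: p * (g - sum(p * g))
    p = softmax(logits)
    return p * (upstream - np.sum(p * upstream, axis=-1, keepdims=True))

logits = np.array([[0.2, 1.5, -0.3],
                   [2.0, 0.1, 0.4]])
upstream = np.array([[1.0, -2.0, 0.5],
                     [0.3, 0.0, 1.0]])
print(st_argmax_forward(logits))             # one-hot rows at index 1 and index 0
print(st_argmax_backward(logits, upstream))  # non-zero gradient for the logits
```

Because `encoder_dense` already applies softmax, the same effect is often obtained in Keras with the stop-gradient trick, e.g. `Lambda(lambda p: p + K.stop_gradient(K.one_hot(K.argmax(p, axis=-1), vocabulary) - p))`. Its forward value is the one-hot vector while the gradient flows into `p` unchanged. This is an untested assumption about this model, and the downstream layers would then have to accept one-hot vectors (for example a `Dense` layer in place of `Embedding`).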

Layer (type)                 Output Shape              Param #   
input_7 (InputLayer)         (None, 32)                0         
embedding_13 (Embedding)     (None, 32, 512)           2018816   
lstm_19 (LSTM)               (None, 32, 512)           2099200   
dropout_10 (Dropout)         (None, 32, 512)           0         
dense_19 (Dense)             (None, 32, 3943)          2022759   
lambda_5 (Lambda)            (None, 32)                0         
lambda_6 (Lambda)            (None, 32)                0         
dense_20 (Dense)             (None, 501)               16533     
embedding_14 (Embedding)     (None, 501, 512)          2018816   
lstm_20 (LSTM)               (None, 501, 512)          2099200   
lstm_21 (LSTM)               (None, 501, 512)          2099200   
dropout_11 (Dropout)         (None, 501, 512)          0         
dense_21 (Dense)             (None, 501, 3943)         2022759   
Total params: 14,397,283
Trainable params: 14,397,283
Non-trainable params: 0
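
The parameter counts in this summary can be reproduced by hand, which also shows how many trainable weights sit upstream of the argmax. This is a small sketch assuming the standard Keras formulas: Embedding is vocab × dim, Dense is in × out plus an out-sized bias, and LSTM is 4 × (in × units + units × units + units). The sizes 3943, 32 and 501 are read off the summary.

```python
# Keras parameter-count formulas for the layer types in this model
def embedding_params(vocab, dim):
    return vocab * dim  # one dim-sized vector per token id

def dense_params(units_in, units_out):
    return units_in * units_out + units_out  # kernel + bias

def lstm_params(units_in, units):
    # 4 gates, each with an input kernel, a recurrent kernel and a bias
    return 4 * (units_in * units + units * units + units)

# Sizes read off the model summary
vocabulary, latent_dim, seq_len, label_len = 3943, 512, 32, 501

layers = {
    "embedding_13": embedding_params(vocabulary, latent_dim),
    "lstm_19": lstm_params(latent_dim, latent_dim),
    "dense_19": dense_params(latent_dim, vocabulary),  # Dense acts on the last axis
    "lambda_5": 0,  # K.argmax: no weights, no gradient
    "lambda_6": 0,  # K.cast: no weights
    "dense_20": dense_params(seq_len, label_len),
    "embedding_14": embedding_params(vocabulary, latent_dim),
    "lstm_20": lstm_params(latent_dim, latent_dim),
    "lstm_21": lstm_params(latent_dim, latent_dim),
    "dense_21": dense_params(latent_dim, vocabulary),
}

total = sum(layers.values())
# Weights upstream of the argmax: their gradients come back as None
upstream = layers["embedding_13"] + layers["lstm_19"] + layers["dense_19"]
print(f"Total params: {total:,}")           # Total params: 14,397,283
print(f"Upstream of argmax: {upstream:,}")  # Upstream of argmax: 6,140,775
```

About 6.1 million of the 14.4 million weights (`embedding_13`, `lstm_19`, `dense_19`) can only receive a gradient through `lambda_5`. Since argmax provides none, Keras refuses to build the training step.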


