Nlp Keras model.predict-argmax始终输出0(seq2seq模型)

Nlp Keras model.predict-argmax始终输出0(seq2seq模型),nlp,deep-learning,keras,rnn,argmax,Nlp,Deep Learning,Keras,Rnn,Argmax,我正在使用拼写蜜蜂代码,并将其应用于类似的seq2seq任务。我正在努力处理argmax函数的预测和输出。出于某种原因,argmax的输出在任何情况下都只返回0 我更改了很多参数,选择了其他轴。。但似乎什么都不管用。如果有任何线索,我将不胜感激 我有 parms = {'verbose': 2} lstm_params = {} def get_rnn(return_sequences= True): return LSTM(dim, return_sequences=retur

我正在使用拼写蜜蜂代码,并将其应用于类似的seq2seq任务。我正在努力处理argmax函数的预测和输出。出于某种原因,argmax的输出在任何情况下都只返回0

我更改了很多参数,选择了其他轴。。但似乎什么都不管用。如果有任何线索,我将不胜感激

我有

parms = {'verbose': 2}

lstm_params = {}


def get_rnn(return_sequences= True): 
    return LSTM(dim, return_sequences=return_sequences, recurrent_dropout=0.2, implementation=1, dropout=0.2)



inp = Input((maxlen_p,))
x = Embedding(input_vocab_size, 60)(inp)

x = Bidirectional(get_rnn())(x)
x = get_rnn(False)(x)

x = RepeatVector(maxlen)(x)
x = get_rnn()(x)
x = get_rnn()(x)
x = TimeDistributed(Dense(output_vocab_size, activation='softmax'))(x)


model = Model(inp, x)

model.compile(loss='sparse_categorical_crossentropy', optimizer='Adam', metrics=['acc'])

hist=model.fit(input_train, np.expand_dims(labels_train,-1), 
          validation_data=[input_test, np.expand_dims(labels_test,-1)], 
          batch_size=128, **parms, epochs=3)

SVG(model_to_dot(model).create(prog='dot', format='svg'))
在此函数中,存在以下问题:

def eval_keras(input):

    preds = model.predict(input, batch_size=128)
    predict = np.argmax(preds, axis = 2) 
    return (np.mean([all(real==p) for real, p in zip(labels_test, predict)]), predict)
preds:

[[[9.1350e-02 4.3054e-04 7.0428e-04 ... 5.0601e-04 7.1275e-04 5.8476e-03]
  [5.9895e-01 3.5628e-05 7.1672e-05 ... 4.2559e-05 7.7454e-05 3.3954e-03]
  [7.2249e-01 1.6146e-05 3.3864e-05 ... 2.0008e-05 3.8247e-05 2.3551e-03] ...
预测:

[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
[ 24 360  60  80 585  73 706 595 766 625 240 284   8   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0]
真实(值):

[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
[ 24 360  60  80 585  73 706 595 766 625 240 284   8   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0]

不幸的是,你们的班级非常不平衡。尝试使用
class_权重
,以补偿大多数示例等于
0
的事实。