为什么这个简单的keras 3类分类器只预测一个类而不是其他类?

为什么这个简单的keras 3类分类器只预测一个类而不是其他类?,keras,deep-learning,Keras,Deep Learning,我正在尝试使用keras创建一个简单的3类深度学习分类器,如下所示: clf = Sequential() clf.add(Dense(20, activation='relu', input_dim=NUM_OF_FEATURES)) clf.add(Dense(10, activation='relu')) clf.add(Dense(3, activation='relu')) clf.add(Dense(1, activation='softmax')) # Model Compil

我正在尝试使用keras创建一个简单的3类深度学习分类器,如下所示:

clf = Sequential()
clf.add(Dense(20, activation='relu', input_dim=NUM_OF_FEATURES))
clf.add(Dense(10, activation='relu'))
clf.add(Dense(3, activation='relu'))
clf.add(Dense(1, activation='softmax'))


# Model Compilation
clf.compile(optimizer = 'adam',
             loss = 'categorical_crossentropy',
             metrics = ['accuracy'])

# Training the model
clf.fit(X_train,
         y_train,
         epochs=10,
         batch_size=16,
         validation_data=(X_val, y_val))
如何在进行预测训练后,始终只预测3个班中的同一个班(1班)

我的网络架构不正确吗

我对深度学习和人工智能一无所知


提前感谢。

如果您希望网络将三个类进行分类,那么最后一个密集层应该有三个输出节点。在该示例中,最后一个密集层有一个输出节点

clf = Sequential()
clf.add(Dense(20, activation='relu', input_dim=NUM_OF_FEATURES))
clf.add(Dense(10, activation='relu'))
clf.add(Dense(3, activation='relu'))
clf.add(Dense(3, activation='softmax'))
对于每个输入样本,输出将是三个值,所有值相加为一。这些表示输入属于每个类的概率

关于损失函数,如果你想使用交叉熵,你可以选择稀疏分类交叉熵和分类交叉熵。后者期望地面真相标签是一个热编码的(您可以使用
tf.one\u hot
)。换句话说,标签的形状与网络输出的形状相同。另一方面,稀疏分类交叉熵期望标签的秩为N-1,其中N是神经网络输出的秩。在顺序词中,这些是热编码之前的标签

当模型用于推理时,可以使用预测的最后一个维度的argmax检索预测的类值

predictions = clf.predict(x)
classes = predictions.argmax(-1)

请提供编译模型的代码部分。@Chicodelarosa已更新,谢谢,但最后一个密集层有3个输出给了我以下错误:
ValueError:Shapes(None,1)和(None,3)不兼容
,这可能是由于函数丢失造成的。您的
y\u序列的值是否为整数0、1和2?如果是这样,您应该使用
sparse\u categorical\u crossentropy
作为损失函数
categorical\u crossentropy
需要一个热编码标签,而
sparse\u categorical\u crossentropy
需要非一个热标签。哪个更好
sparse\u Category\u crossentropy
Category\u crossentropy
?它们计算相同的损失(即交叉熵),但期望不同形式的标签
Category\u crossentropy
需要一个热编码标签(在您的例子中是2D),而
sparse\u Category\u crossentropy
需要非一个热编码标签(在您的例子中是1D)。好的,但是
sparse\u Category\u crossentropy
会生成一个浮点值矩阵
(数组([0.7737731,0.19874486,0.02748196]),[0.86113244,0.12877706,0.01009052],[0.9258943,0.07132111,0.00278455],…,[0.8239085,0.15966693,0.01642458],[0.80182225,0.1772431,0.02093465],[0.9042275,0.0910617,0.00471088],[dtype=float32])
。我想得到预测的类标签为0、1或2。