TensorFlow: accuracy results don't make sense in multi-label classification with deep learning
I have a multi-label classification problem. I used the code below, but the validation accuracy doesn't make sense: it stays below 20%. On the same data, a classical machine-learning approach gets over 80%. What could be wrong here? Here is the code:
```python
import pandas as pd
from sklearn.model_selection import train_test_split
from tensorflow.keras import backend as K
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (Embedding, SpatialDropout1D, Conv1D,
                                     BatchNormalization, Dropout,
                                     GlobalMaxPool1D, Dense)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing import sequence

data_path = "TestData.csv"
data_raw = pd.read_csv(data_path)
data_raw.shape  # (156004, 9)
categories = list(data_raw.columns.values)
categories = categories[1:]  # label columns
data = data_raw
data.shape  # (156004, 9)
num_words = 20000
max_features = 150000
max_len = 200
embedding_dims = 128
num_epochs = 5
X_train = data["text"].values
X_test = data["text"].values

# Tokenization
tokenizer = Tokenizer(num_words=num_words)
tokenizer.fit_on_texts(list(X_train))
X_train = tokenizer.texts_to_sequences(X_train)
X_test = tokenizer.texts_to_sequences(X_test)
X_train = sequence.pad_sequences(X_train, maxlen=max_len)
X_test = sequence.pad_sequences(X_test, maxlen=max_len)
y_train = data[categories].values
y_test = data[categories].values
X_tra, X_val, y_tra, y_val = train_test_split(X_train, y_train)  # , test_size=0.2, random_state=0)

CNN_model = Sequential([
    # Tokenizer(num_words) only emits indices below num_words, so rows >= num_words are never looked up
    Embedding(input_dim=max_features, input_length=max_len, output_dim=embedding_dims),
    SpatialDropout1D(0.5),
    Conv1D(filters=20, kernel_size=8, padding='same', activation='relu'),
    BatchNormalization(),
    Dropout(0.5),
    GlobalMaxPool1D(),
    Dense(8, activation='sigmoid')
])

def mean_pred(y_true, y_pred):
    return K.mean(y_pred)

adam = Adam()
CNN_model.compile(loss='binary_crossentropy', optimizer=adam, metrics=['accuracy'])
for category in categories:
    print('**Processing {} ...**'.format(category))
    # `category` is not used below, so each iteration re-fits the same model on all labels
    pred = CNN_model.fit(X_tra, y_tra, batch_size=128, epochs=num_epochs, validation_data=(X_val, y_val))
```
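The `metrics=['accuracy']` argument in the compile call is a likely suspect. As far as I can tell, in tf.keras 2.x the string `'accuracy'` is resolved from the shapes of the targets and the output layer rather than from the loss: a single output unit gives binary accuracy, but with `Dense(8, activation='sigmoid')` and 8-column targets it resolves to categorical accuracy, which only checks whether the arg-max of the prediction row equals the arg-max of the label row. For rows with several positive labels, or with no positive label at all (whose arg-max is 0), that number says little about the model. The NumPy sketch below compares it with two common multi-label definitions on a small hypothetical batch; the function names and data are illustrative and are not part of Keras.

```python
import numpy as np


def categorical_accuracy(y_true, y_pred):
    # A row is "correct" only if the arg-max of the prediction equals the arg-max of the label row.
    return float(np.mean(np.argmax(y_true, axis=1) == np.argmax(y_pred, axis=1)))


def binary_accuracy(y_true, y_pred, threshold=0.5):
    # Every (row, label) cell is scored independently after thresholding (1 - Hamming loss).
    return float(np.mean((y_pred > threshold) == (y_true == 1)))


def exact_match_accuracy(y_true, y_pred, threshold=0.5):
    # A row counts only if all of its labels are predicted correctly (subset accuracy).
    return float(np.mean(np.all((y_pred > threshold) == (y_true == 1), axis=1)))


# Hypothetical batch: 4 samples, 3 labels; the first two rows have no positive label.
y_true = np.array([[0, 0, 0],
                   [0, 0, 0],
                   [0, 1, 0],
                   [1, 0, 1]])
y_pred = np.array([[0.10, 0.30, 0.20],
                   [0.05, 0.10, 0.20],
                   [0.10, 0.90, 0.20],
                   [0.30, 0.10, 0.80]])

print(categorical_accuracy(y_true, y_pred))       # 0.25
print(round(binary_accuracy(y_true, y_pred), 3))  # 0.917
print(exact_match_accuracy(y_true, y_pred))       # 0.75
```

The same predictions score 25%, about 92% or 75% depending on the definition. Per-label accuracy is what `binary_accuracy` computes, while `sklearn.metrics.accuracy_score` returns exact-match (subset) accuracy for multi-label indicator arrays, so it is worth checking which definition the 80% figure from the machine-learning approach used before comparing the two numbers.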
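Standard Keras accuracy does not support multi-label classification.

@Dr.Snoopy But I applied this code to another dataset (toxic) and it works fine.

@Dr.Snoopy In that case, how do I get the correct accuracy?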
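One way to get a meaningful accuracy is to name the metrics explicitly instead of passing the string `'accuracy'`, so Keras cannot fall back to categorical accuracy. The sketch below assumes TensorFlow 2.x and uses `tf.keras.metrics.BinaryAccuracy` (per-label accuracy at a 0.5 threshold) and `tf.keras.metrics.AUC(multi_label=True)`, which is often more informative when most labels are rarely positive. The tiny model, its 16 input features and the random data are hypothetical stand-ins that only demonstrate the compile/evaluate calls; in the question's code the same `metrics=[...]` list would replace `metrics=['accuracy']` in `CNN_model.compile(...)`.

```python
import numpy as np
import tensorflow as tf

num_labels = 8  # matches Dense(8, activation='sigmoid') in the question

# Hypothetical stand-in model; only the sigmoid output layer matters for the metric choice.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(16,)),
    tf.keras.layers.Dense(num_labels, activation="sigmoid"),
])

model.compile(
    loss="binary_crossentropy",
    optimizer="adam",
    metrics=[
        # Scores each (sample, label) cell separately after thresholding at 0.5.
        tf.keras.metrics.BinaryAccuracy(name="binary_accuracy", threshold=0.5),
        # Computes one ROC AUC per label, then averages across labels.
        tf.keras.metrics.AUC(multi_label=True, name="auc"),
    ],
)

# Random placeholder data: roughly 20% positive labels, no real signal.
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 16)).astype("float32")
y = (rng.random((64, num_labels)) < 0.2).astype("float32")

results = model.evaluate(x, y, verbose=0, return_dict=True)
print(results)  # dict with keys 'loss', 'binary_accuracy' and 'auc'
```

Passing the string `'binary_accuracy'` instead of the metric object should select the same per-label metric. Because the labels here are random, the printed values say nothing about how the real model performs.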