Python 我的神经网络模型有什么问题?

Python 我的神经网络模型有什么问题?,python,machine-learning,tensorflow,deep-learning,Python,Machine Learning,Tensorflow,Deep Learning,我得到了一个178个元素的数据集,每个元素包含13个特征和1个标签。 标签存储为一个热数组。我的训练数据集由158个元素组成 以下是我的模型的外观: x = tf.placeholder(tf.float32, [None,training_data.shape[1]]) y_ = tf.placeholder(tf.float32, [None,training_data_labels.shape[1]]) node_1 = 300 node_2 = 300 node_3 = 300 out

我得到了一个178个元素的数据集,每个元素包含13个特征和1个标签。 标签存储为一个热数组。我的训练数据集由158个元素组成

以下是我的模型的外观:

x = tf.placeholder(tf.float32, [None,training_data.shape[1]])
y_ = tf.placeholder(tf.float32, [None,training_data_labels.shape[1]])

node_1 = 300
node_2 = 300
node_3 = 300
out_n = 3   

#1
W1 = tf.Variable(tf.random_normal([training_data.shape[1], node_1]))
B1 = tf.Variable(tf.random_normal([node_1]))
y1 = tf.add(tf.matmul(x,W1),B1)
y1 = tf.nn.relu(y1)

#2
W2 = tf.Variable(tf.random_normal([node_1, node_2]))
B2 = tf.Variable(tf.random_normal([node_2]))
y2 = tf.add(tf.matmul(y1,W2),B2)
y2 = tf.nn.relu(y2)

#3
W3 = tf.Variable(tf.random_normal([node_2, node_3]))
B3 = tf.Variable(tf.random_normal([node_3]))
y3 = tf.add(tf.matmul(y2,W3),B3)
y3 = tf.nn.relu(y3)

#output
W4 = tf.Variable(tf.random_normal([node_3, out_n]))
B4 = tf.Variable(tf.random_normal([out_n]))
y4 = tf.add(tf.matmul(y3,W4),B4)
y = tf.nn.softmax(y4)

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(200):
        sess.run(optimizer,feed_dict={x:training_data, y_:training_data_labels})

    correct = tf.equal(tf.argmax(y_, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct, 'float'))
    print('Accuracy:',accuracy.eval({x:eval_data, y_:eval_data_labels}))
但是准确度很低,我尝试将范围200增加到更高的数字,但仍然很低


我可以做些什么来改善结果

问题在于,您使用的是
y4的softmax,然后将其传递给
tf.nn.softmax\u cross\u entropy\u和\u logits
。此错误非常常见,文档中实际上有一条关于此错误的注释:


剩下的代码看起来很好,因此只需将
y4
替换为
y
,并去掉
y=tf.nn.softmax(y4)
,初始化很可能会出现问题。提供您的培训数据以进行复制this@Skam这是一个家庭作业。为什么你需要argmax方法?我认为您应该删除这两个argmax方法,然后您就可以开始了
WARNING: This op expects unscaled logits, since it performs a softmax on logits internally 
for efficiency. Do not call this op with the output of softmax, as it will produce 
incorrect results.