Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/python/324.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python TensorFlow从mnist数据集中选择标签_Python_Tensorflow - Fatal编程技术网

Python TensorFlow从mnist数据集中选择标签

Python TensorFlow从mnist数据集中选择标签,python,tensorflow,Python,Tensorflow,我正在使用tensorflow.examples.tutorials.mnist来训练具有5个隐藏层的nn 这是我训练神经网络的方法: with tf.Session() as sess: init.run() for epoch in range(n_epochs): for iteration in range(len(mnist.test.labels)//batch_size): X_batch, y_batch = mnist.train.next_batch(

我正在使用tensorflow.examples.tutorials.mnist来训练具有5个隐藏层的nn

这是我训练神经网络的方法:

with tf.Session() as sess:
init.run()
for epoch in range(n_epochs):
    for iteration in range(len(mnist.test.labels)//batch_size):
        X_batch, y_batch = mnist.train.next_batch(batch_size)
        sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
    acc_train = accuracy.eval(feed_dict={X: X_batch, y: y_batch})
    acc_test = accuracy.eval(feed_dict={X: mnist.test.images, y: mnist.test.labels})
    print(epoch, "Train accuracy:", acc_train, "Test accuracy:", acc_test)
我想训练神经网络只识别从0到4的数字。我将logits层更改为有5个输出


如何过滤TensorFlow提供的mnist数据集,以便只获取0到4之间的数字?

有很多方法可以做到这一点。其中之一是当您提取
X\u批时,y\u批=mnist.train.next\u批(批大小)
。在这一步中,您的
y_批
将具有有关数字值的信息(数字值或数字的一个热点)

您迭代批处理中的示例,并检查数字是否是您关心的数字。如果是,则将其添加到
清理批中。效率不是很高,但它会起作用


对评论的答复:


这是没有效率的,因为您可能需要多次过滤相同的数据。我不认为这会是一个问题,因为MNIST是超小的。通常的方法是只过滤一次,创建一个新的数据集,然后编写自己的函数从中获取下一批数据(实际上很容易,因为您只需从数据集中随机选择k个元素)

谢谢,这是一种方法,但您是对的,不是很有效,您知道其他方法吗?我在想我可能可以将一些参数传递给下一批,但找不到