Python 限制MNIST训练数据的大小

Python 限制MNIST训练数据的大小,python,tensorflow,mnist,Python,Tensorflow,Mnist,我刚刚开始学习python和TensorFlow,正在尝试各种神经网络和MNIST数据。我想做的一个实验是看看训练集的大小如何影响性能。目前,培训集中似乎有55000个输入/输出对。我想通过某种方式限制培训只使用前1000个左右,但不知道如何实现这一点 我当前的培训功能如下所示: def do_training(): print("Train entry") for i in range(2000): batch_of_training_inputs, batc

我刚刚开始学习python和TensorFlow,正在尝试各种神经网络和MNIST数据。我想做的一个实验是看看训练集的大小如何影响性能。目前,培训集中似乎有55000个输入/输出对。我想通过某种方式限制培训只使用前1000个左右,但不知道如何实现这一点

我当前的培训功能如下所示:

def do_training():
    print("Train entry")
    for i in range(2000):

        batch_of_training_inputs, batch_of_training_labels = mnist.train.next_batch(100)

        sess.run(train_step, feed_dict={generic_image_data_struct: batch_of_training_inputs, target_for_output_struct: batch_of_training_labels })
有没有像

mnist.train.next_batch(100, BUT_ONLY_FROM_FIRST(1000))
仅供参考,我收到带有以下代码的mnist:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

您可以做的一件简单的事情就是增加验证数据集的大小。MNIST包含60000张图像,因此如果您只想在1000张图像上进行训练,您可以执行以下操作:

mnist = input_data.read_data_sets(train_dir, one_hot=True, validation_size=59000)

经过一点黑客攻击,我认为这可能会奏效。虽然我确实不建议将来依赖这个解决方案,因为它依赖于
DataSet的内部实现。对于一个快速的实验,它可能是好的

来自tensorflow.examples.tutorials.mnist导入输入数据
从tensorflow.contrib.learn.python.learn.datasets.mnist导入数据集
从tensorflow.python.framework导入数据类型
mnist=输入数据。读取数据集('mnist\U数据',one\U hot=真)
train_small=数据集(mnist.train.images[:1000],mnist.train.labels[:1000],
dtype=dtypes.uint8,重塑=False,种子=None)