Python 使用自己的csv文件进行批处理
我有一个使用Tensorflow的MNIST数据集的python代码。 会议安排如下:Python 使用自己的csv文件进行批处理,python,csv,tensorflow,neural-network,batch-processing,Python,Csv,Tensorflow,Neural Network,Batch Processing,我有一个使用Tensorflow的MNIST数据集的python代码。 会议安排如下: with tf.Session() as sess: sess.run(tf.initialize_all_variables()) for epoch in range(hm_epochs): epoch_loss = 0 for _ in range(int(mnist.train.num_examples / bat
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
for epoch in range(hm_epochs):
epoch_loss = 0
for _ in range(int(mnist.train.num_examples / batch_size)):
epoch_x, epoch_y = mnist.train.next_batch(batch_size)
_, c = sess.run([optimizer, cost], feed_dict={x: epoch_x, y: epoch_y})
epoch_loss += c
print('Epoch: ', epoch, ' completed out of: ', hm_epochs, ' loss: ', epoch_loss)
correct = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct, 'float'))
print('Accuracy:', accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
该行:
epoch_x, epoch_y = mnist.train.next_batch(batch_size)
每次都要生产100号的新批次
我的问题是,如果我有自己的CSV文件(这是一个列表),如何用一个新行替换这一行,为我创建新的批?
我当前的代码如下所示:
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
for epoch in range(hm_epochs):
epoch_loss = 0
for _ in range(len(training_data_list) // batch_size):
epoch_x, epoch_y = training_data_list.nextbatch(batch_size)
_, c = sess.run([optimizer, cost], feed_dict={x: epoch_x, y: epoch_y})
epoch_loss += c
print('Epoch: ', epoch, ' completed out of: ', hm_epochs, ' loss: ', epoch_loss)
correct = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct, 'float'))
print('Accuracy:', accuracy.eval({x: inputs, y: targets}))
其中“nextbatch”是我定义的函数。但我得到了以下错误:
AttributeError: 'list' object has no attribute 'nextbatch'
谢谢你的建议:D
顺便说一下,“培训数据列表”来自:
stops = open('.../Desktop/stops_train.csv', 'r')
training_data_list = stops.readlines()
stops.close()
您需要实现一个处理索引的对象。 您需要在该对象内实现nextbatch函数。
您可以在mnist中查看nextbatch的实现。谢谢您的评论。同时,我使用了这个:对于范围内的(len(培训数据列表)):批处理x=输入列[*批处理大小:(+1)*批处理大小]批处理y=目标列[*批处理大小:(+1)*批处理大小]。