Python Theano小批量迭代器不工作

Python Theano小批量迭代器不工作,python,theano,lasagne,Python,Theano,Lasagne,无小批量迭代器不工作 我编写了一个小批量迭代器,从我的神经网络中得到预测结果。 然而,我做了一些测试,发现了一些错误 基本上: If batch_size > amount of inputs : error 我制作了一个脚本来在代码中显示这个bug。其结果如下: import numpy as np def minibatch_iterator_predictor(inputs, batch_size): assert len(inputs) > 0 for

无小批量迭代器不工作

我编写了一个小批量迭代器,从我的神经网络中得到预测结果。 然而,我做了一些测试,发现了一些错误

基本上:

If batch_size > amount of inputs  : error
我制作了一个脚本来在代码中显示这个bug。其结果如下:

import numpy as np

def minibatch_iterator_predictor(inputs, batch_size):
    assert len(inputs) > 0

    for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):
        excerpt = slice(start_idx, start_idx + batch_size)
        yield inputs[excerpt]


def test(x, batch_size):
    prediction = np.empty((x.shape[0], 2), dtype=np.float32)

    index = 0
    for batch in minibatch_iterator_predictor(inputs=x, batch_size=batch_size):
        inputs = batch

        # y = self.predict_function(inputs)
        y = inputs

        prediction[index * batch_size:batch_size * (index + 1), :] = y[:]
        index += 1
    return prediction

######################################
#TEST SCRIPT
######################################

#Input
arr = np.zeros(shape=(10, 2))

arr[0] = [1, 0]
arr[1] = [2, 0]
arr[2] = [3, 0]
arr[3] = [4, 0]
arr[4] = [5, 0]
arr[5] = [6, 0]
arr[6] = [7, 0]
arr[7] = [8, 0]
arr[8] = [9, 0]
arr[9] = [10, 0]

###############################################

batch_size = 5
print "\nBatch_size ", batch_size
r = test(x=arr, batch_size=batch_size)

#Debug
for k in xrange(r.shape[0]):
        print str(k) + " : " + str(r[k])

##Assert

assert arr.shape[0] == r.shape[0]

for k in xrange(0,r.shape[0]):
    print r[k] == arr[k]
以下是测试 对于批量大小=10的情况:

Batch_size  10
0 : [ 1.  0.]
1 : [ 2.  0.]
2 : [ 3.  0.]
3 : [ 4.  0.]
4 : [ 5.  0.]
5 : [ 6.  0.]
6 : [ 7.  0.]
7 : [ 8.  0.]
8 : [ 9.  0.]
9 : [ 10.   0.]
对于批量大小=11的情况:

0 : [  1.13876845e-37   0.00000000e+00]
1 : [  1.14048027e-37   0.00000000e+00]
2 : [  1.14048745e-37   0.00000000e+00]
3 : [  9.65151604e-38   0.00000000e+00]
4 : [  1.14002468e-37   0.00000000e+00]
5 : [  1.14340036e-37   0.00000000e+00]
6 : [  1.14343264e-37   0.00000000e+00]
7 : [  8.02794698e-38   0.00000000e+00]
8 : [  8.02794698e-38   0.00000000e+00]
9 : [  8.02794698e-38   0.00000000e+00]
适用于批量大小为12的产品

0 : [  1.13876845e-37   0.00000000e+00]
1 : [  1.14048027e-37   0.00000000e+00]
2 : [  1.14048745e-37   0.00000000e+00]
3 : [  9.65151604e-38   0.00000000e+00]
4 : [  1.14002468e-37   0.00000000e+00]
5 : [  1.14340036e-37   0.00000000e+00]
6 : [  1.14343264e-37   0.00000000e+00]
7 : [  8.10141537e-38   0.00000000e+00]
8 : [  8.10141537e-38   0.00000000e+00]
9 : [  8.10141537e-38   0.00000000e+00]

我如何解决这个问题?

请尽量在问题中更具体一些。你到底想修什么

没有任何错误。 当批大小大于输入时,函数
minibatch\u iterator\u predictor
将生成一个空迭代器,并且不会执行minibatch\u iterator\u predictor中批的循环
(输入=x,批大小=批大小)

当批处理大小大于输入数量时,得到的只是初始化后的零:
prediction=np.empty((x.shape[0],2),dtype=np.float32)

您可以将最大批量大小限制为输入的数量:

def minibatch_iterator_predictor(inputs, batch_size):
    assert len(inputs) > 0
    if batch_size > len(inputs):
        batch_size = len(inputs)

    for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):
        excerpt = slice(start_idx, start_idx + batch_size)
        yield inputs[excerpt]