Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/python/334.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python递归函数失败_Python_Recursion_Neural Network_Pytorch - Fatal编程技术网

Python递归函数失败

Python递归函数失败,python,recursion,neural-network,pytorch,Python,Recursion,Neural Network,Pytorch,我遇到的问题真的很奇怪 我试图完成的是:我正在使用pytorch训练一个神经网络,如果训练损失没有减少,我想重新启动我的训练函数,以便用一组不同的权重重新初始化神经网络。培训功能如下所示: def __train__(dp, i, j, net, restarts, epoch=0): if net == '2CH': model = TwoChannelCNN().cuda() elif net == 'Siam' : model = SiameseCNN().cuda()

我遇到的问题真的很奇怪

我试图完成的是:我正在使用pytorch训练一个神经网络,如果训练损失没有减少,我想重新启动我的训练函数,以便用一组不同的权重重新初始化神经网络。培训功能如下所示:

def __train__(dp, i, j, net, restarts, epoch=0):
    if net == '2CH': model = TwoChannelCNN().cuda()
    elif net == 'Siam' : model = SiameseCNN().cuda()
    elif net == 'Trad' : model = TraditionalCNN().cuda()
    ls_fn = torch.nn.MSELoss(reduce=True)
    optim = torch.optim.SGD(model.parameters(),  lr=1e-6, momentum=0.9)
    epochs = np.arange(100)
    eloss = []
    for epoch in epochs:
        model.train()
        train_loss = []
        tr_batches = np.array_split(dp.train_set, int(len(dp.train_set)/8))
        for tr_batch in tr_batches:
            if net == '2CH': loaded_batch = dp.__load2CH__(tr_batch)
            elif net == 'Siam': loaded_batch = dp.__loadSiam__(tr_batch)
            elif net == 'Trad' : loaded_batch = dp.__load__(tr_batch, i)
            for x_batch, y_batch in loaded_batch:
                x_var, y_var = Variable(x_batch.cuda()), Variable(y_batch.cuda())
                y_pred = torch.clamp(model(x_var), 0, 1)
                loss = ls_fn(y_pred, y_var)
                train_loss.append(abs(loss.item()))
                optim.zero_grad()
                loss.backward()
                optim.step()
        eloss.append(np.mean(train_loss))
        print(epoch, np.mean(train_loss))
        if epoch == 10 and np.mean(train_loss) > 0.2:
            restarts += 1
            print('Number of restarts for client {} and fold {}: {}'.format(i,j,restarts))
            __train__(dp, i, j, net, restarts, epoch=0)

    __plotLoss__(epochs, eloss, 'train', str(i), str(j))
    torch.save(model.state_dict(), "Output/client_{}_fold_{}.pt".format(i, j))
因此,如果epoch==10且np.mean(train_loss)>0.2,则基于
的重新启动是有效的,但只是有时有效,这超出了我的理解。以下是一个输出示例:

0 0.5000133737921715
1 0.4999906486272812
2 0.464298670232296
3 0.2727506290078163
4 0.2628978116512299
5 0.2588871221542358
6 0.25728522151708605
7 0.25630473804473874
8 0.2556223524808884
9 0.25522999209165576
10 0.25467908215522767
Number of restarts for client 5 and fold 1: 3
0 0.10957609283713009
1 0.02840371729924134
2 0.021477583368030594
3 0.017759160268232682
4 0.015173796122947827
5 0.013349939693290782
6 0.011949078906879265
7 0.010810676779671655
8 0.00987362345259362
9 0.009110640348696108
10 0.008239036202623808
11 0.007680381585537574
12 0.007171026876221333
13 0.006765962297888837
14 0.006428168776848068
15 0.006133011780953467
16 0.005819878347673745
17 0.005572605537395361
18 0.00535818950227004
19 0.005159409143814457
20 0.0049763926251294235
21 0.004738794513338235
22 0.004578812885309958
23 0.004428663117960554
24 0.004282198464788351
25 0.004145324644400691
26 0.004018862769889626
27 0.0039044404603504573
28 0.0037960831121495744
29 0.0036947361258523586
30 0.0035982220717533267
31 0.0035018146670104723
32 0.0034150678806059887
33 0.0033372560733512698
34 0.003261332974241583
35 0.00318166259540763
36 0.003108531899014735
37 0.0030385089141125848
38 0.002977990984523103
39 0.0029195284016142937
40 0.002870084639441188
41 0.0028180573325994373
42 0.0027717544270049643
43 0.002719321814503495
44 0.0026704726860933194
45 0.0026204266263459316
46 0.002570544072460258
47 0.0025225681523167224
48 0.0024814611543610746
49 0.0024358948737413116
50 0.002398673941639636
51 0.0023606415423654587
52 0.002330436484101057
53 0.0022891738560574027
54 0.002260655496376241
55 0.002227568955708719
56 0.002191826719741698
57 0.0021609061182290058
58 0.0021279943092100666
59 0.0020966088490456513
60 0.002066195117003474
61 0.0020381672924407895
62 0.002009863329306995
63 0.001986304977759602
64 0.0019564831849032487
65 0.0019351609173580756
66 0.0019077356409993626
67 0.0018875047204855945
68 0.0018617453310780547
69 0.001839518720600381
70 0.001815563331498197
71 0.0017149778925132932
72 0.0016894878409248121
73 0.0016652211918212743
74 0.0016422999463582074
75 0.0016183732903472788
76 0.0015962369183098418
77 0.0015757764620279887
78 0.0015542267022799728
79 0.0015323152910759318
80 0.0014337954093957706
81 0.001410489170542867
82 0.0013871921329466962
83 0.0013641994057461773
84 0.001345829172682187
85 0.001322142209181493
86 0.00130379223035348
87 0.001282231878045458
88 0.001263879886683956
89 0.001243419097817167
90 0.0012279346547037929
91 0.001206978429649382
92 0.0011871445969959496
93 0.001172510546330841
94 0.0011529557384797045
95 0.0011350733004023273
96 0.001118382818282214
97 0.001103347793609089
98 0.0010848538354748599
99 0.0010698940242660911
11 0.2542190085053444
12 0.2538975296020508
在这里,您可以看到,从第三次重新启动开始,重新启动是正确的,但是,由于网络聚合,训练应该完成,但是函数在第99个历元之后再次重新启动(原因未知),并且不知何故在第11个历元开始,这也没有意义,因为每当函数启动或重新启动时,我都显式地指定
epoch=0
。我还应该补充一点,有时,函数会在epoch 99之后正确完成,此时已实现收敛,并且不会重新启动


所以我的问题是,为什么这段代码会产生不一致的结果和结果?我错过了什么?提前感谢您的建议。

如果epoch==10且为np,则第二次调用
\uuuu train\uuuu
将重新启动培训。平均值(train\u loss)>0.2:
,但您永远不会终止第一个循环。 因此,在第二次训练收敛后,外循环在第11纪元继续


您需要的是一个
break
语句,该语句位于对
\uuuuuuuuuuuuu\uuuuuuuuu

的内部调用之后。这里的一个混淆源是,您以两种不同的方式使用名为
epoch
的变量:1。作为传递给方法
\uuuu列车\uuuu
和2的参数。作为循环变量
用于历元中的历元:
。在循环的每个开始处,
epoch
将重置为0,因此传递给该方法的值将无关紧要。@user727089感谢您指出这一点。然而,问题仍然在于,即使指定了
epochs=np.arange(100)
,为什么epochs有时从11开始。这意味着,每次调用
\uuuu train\uuuu
,历代都应该从零开始?你说得对。每次训练总是从第0纪元开始。只要确保在满足中止条件后终止培训(参见下面我的答案)。在内部调用
\uuuuuu train\uuuu
之后,返回语句将如何到达?你是说一个空的返回语句吗?我目前正在尝试返回火车,但不知道这是否正确?我只是很难理解调用
\uuu train\uuu
后返回语句将如何到达?是的,你是对的。您需要更多的逻辑来处理输出,例如,如果要重新启动培训,则将
聚合的
变量设置为
False
。然后你可以在最后检查你是否应该打印输出。好吧,当我返回时,它似乎得到了解决。。。在非收敛性方面有一些重新开始,在错误的年代等方面没有问题。我认为,
break
也会起作用,尽管。。。非常感谢你的帮助。如果出现问题,我会再次发表评论,但到目前为止,它看起来很完美:D