Pytorch-关于简单问题的批量规范化
我实现了一个具有批量规范化的模型:Pytorch-关于简单问题的批量规范化,pytorch,batch-normalization,Pytorch,Batch Normalization,我实现了一个具有批量规范化的模型: class FFNet(torch.nn.Module): def __init__(self, D_in, H_1, H_2, D_out): super(FFNet, self).__init__() self.linear1 = torch.nn.Linear(D_in, H_1) self.linear2 = torch.nn.Linear(H_1, H_2) self.bn2
class FFNet(torch.nn.Module):
def __init__(self, D_in, H_1, H_2, D_out):
super(FFNet, self).__init__()
self.linear1 = torch.nn.Linear(D_in, H_1)
self.linear2 = torch.nn.Linear(H_1, H_2)
self.bn2 = torch.nn.BatchNorm1d(H_2)
self.linear4 = torch.nn.Linear(H_2, D_out)
def forward(self, x):
h_relu_1=F.relu(self.linear1(x))
h_relu_2=F.relu(self.bn2(self.linear2(h_relu_1)))
y_pred=self.linear4(h_relu_2)
return y_pred
此外,我还编写了培训循环:
for epoch in range(epoches):
running_loss = 0.0
cnt = 0
for i, data in enumerate(train_data, 0):
local_X, local_y = data
y_pred = model.forward(local_X)
loss = criterion(y_pred, local_y)
optimizer.zero_grad()
#loss = criterion(y_pred, Y_local_output)
loss.backward() # back props
optimizer.step()
running_loss = running_loss + loss.item()
cnt+=1
Validation_loss = 0.0
cnt2 = 0
# Validation
for i, data in enumerate(validation_data, 0):
Val_X, Val_Y = data
y_pred = model.forward(Val_X)
loss=criterion(y_pred, Val_Y)
Validation_loss = Validation_loss + loss.item()
cnt2+=1
我有两个问题:
1.在这段代码中是否不需要使用model.train()?
2.如何使用eval
评估此模型?我有一个数据样本,其大小为(1xD_英寸),批量大小大于1。当我使用以下代码时,出现了一个错误:
test_single = torch.tensor([aa, ab, ac, ad, ae, af, ag])
test_single = test_single.unsqueeze(0)
model.eval()
[bb,cc] = model.forward(test_single)
错误是“没有足够的值来解包(预期为2,得到1)”如果有批标准化,则在分别进行培训和评估时确实需要使用model.train()和model.eval() 第二部分(不奇怪的代码)没有错。但是,您的模型只有一个输出(请参阅模型的forward函数的return语句),这会导致错误,即您试图解包2个值,而只有一个值。所以,你不能这么做
[bb,cc] = model.forward(test_single)
你必须这样做
out = model.forward(test_single)
我试过这个,效果很好。谢谢你的好意。你的意思是我必须在代码中添加句子model.train()和model.eval()?我想知道我应该在培训和验证时添加这些句子。此外,该模型有两个输出。但是,这种错误是存在的……是的,在启动for循环进行训练之前,必须添加
model.train()
。验证和model.eval()相同。对于问题中显示的模型,forward函数只返回一个值。因此,当您调用forward函数(使用model(input)
)时,您不能期望它返回两个输出。