我在pytorch上使用model.train()和model.eval()时出错

我在pytorch上使用model.train()和model.eval()时出错,pytorch,transformer,Pytorch,Transformer,我已经准备好了特征和它们的标签;我想建立一个由变压器编码器构建的模型,然后添加一个线性层来预测一个值。但是,当我在训练后使用该模型进行预测时,我得到了一些错误。 首先,我运行以下代码: import torch from torch import nn features = torch.rand(bach_size, channels, lenght) labels = torch.rand(batch_size) class TransformerModel(nn.Module):

我已经准备好了特征和它们的标签;我想建立一个由变压器编码器构建的模型,然后添加一个线性层来预测一个值。但是,当我在训练后使用该模型进行预测时,我得到了一些错误。 首先,我运行以下代码:

import torch
from torch import nn

features = torch.rand(bach_size, channels, lenght)
labels = torch.rand(batch_size)

class TransformerModel(nn.Module):

  def __init__(self):
    super(TransformerModel, self).__init__()
    encoder_layer = nn.TransformerEncoderLayer(d_model=8, nhead=8, dropout=0.5)
    self.transformer_encoder = nn.TransformerEncoder(encoder_layer, 6)
    self.decoder = nn.Linear(40, 1)

  def forward(self, src):
    encoded = self.transformer_encoder(src.transpose(1, 0)).transpose(1, 0)
    pred = self.decoder(encoded.reshape(encoded.shape[0], -1))
    return pred


model = TransformerModel()
criterion = nn.MSELoss()
lr = 0.3 # learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=lr)


def train():

  model.train() # Turn on the train mode
  optimizer.zero_grad()

  output = model(features)

  loss = criterion(output.view(-1, 1), labels.view(-1, 1))
  loss.backward()
  torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
  optimizer.step()

  return loss.item()



for _ in range(100):
  train()

之后,我通过以下代码预测特征:

model.eval()
output = model(features)

我得到的“output”的所有值都是相同的,如果使用“model.train()”,“output”看起来还可以;那么问题是什么呢?或者模型构建错误?

您能提供输出的小示例吗?(无需复制粘贴整个输出,因为它太大,无法在此处进行拟合)输出值始终与blow相同:张量([[0.4938],[0.4938],…[0.4938],…[0.4938],[0.4938],[grad_fn=)我感到奇怪的是,您使用非常有限的数据训练了相当大的模型。您的数据集就是您的批处理。这将导致您的模型在数据上过度拟合。你试过使用更多的数据吗?事实上,我使用了很多数据,它的形状是[3956,5,8];为了简单地解释这个问题,我发布了一些数据。你们能提供一些输出的小例子吗?(无需复制粘贴整个输出,因为它太大,无法在此处进行拟合)输出值始终与blow相同:张量([[0.4938],[0.4938],…[0.4938],…[0.4938],[0.4938],[grad_fn=)我感到奇怪的是,您使用非常有限的数据训练了相当大的模型。您的数据集就是您的批处理。这将导致您的模型在数据上过度拟合。你试过使用更多的数据吗?事实上,我使用了很多数据,它的形状是[3956,5,8];为了简单地解释这个问题,我发布了一些数据。