Pytorch 卷积网络实现中的维数误差

Pytorch 卷积网络实现中的维数误差,pytorch,conv-neural-network,Pytorch,Conv Neural Network,我试图理解为什么我的分类器存在维度问题。这是我的密码: class convnet(nn.Module): def __init__(self, num_classes=1000): super(convnet, self).__init__() self.features = nn.Sequential( nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),

我试图理解为什么我的分类器存在维度问题。这是我的密码:

class convnet(nn.Module):

    def __init__(self, num_classes=1000):
        super(convnet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(kernel_size=2, stride = 2),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(kernel_size=2, stride = 2), #stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride = 2),
        )

        self.classifier = nn.Sequential(
            nn.Linear(576, 128),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.Linear(64,num_classes),
            nn.Softmax(),
       )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x,1) #x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        return x


def neuralnet(num_classes,**kwargs):
    model = convnet(**kwargs)
    return model

所以这里我的问题是:预期4D输入(得到2D输入)

我很确定这个错误是由flatte命令引起的,但是我真的不明白为什么,因为分类器有完全密集的连接。如果有人知道我哪里出了问题,那将非常有帮助


谢谢

展平后,分类器的输入有两个维度(大小:[批次大小,576]),因此第一个线性层的输出也有两个维度(大小:[批次大小,128])。然后将该输出传递给,这要求其输入具有4个维度(大小:[批次大小、通道、高度、宽度])

如果要在二维输入上使用批处理规范,则需要使用,它接受三维输入(大小:[批处理大小,通道,长度])或二维输入(大小:[批处理大小,长度])

self.classifier=nn.Sequential(
nn.线性(576128),
nn.1D(128),
nn.ReLU(就地=真),
nn.线性(128,64),
nn.ReLU(就地=真),
nn.1D(64),
nn.线性(64个,num_类),
nn.Softmax(),
)

您的错误指向哪一行?它指向x=self.classifier(x),您的意思是
x=self.classifier(x)
提示:始终发布完整的堆栈跟踪我将在下次记住这一点!我明白了!谢谢你的详细回复!