Python: how to read this modified UNet?


This code is a modified UNet that I am working with. I am having a hard time reading and understanding the code, in particular how the skip connections are wired into the upsampling. Can someone explain it? Or can it be rewritten without nn.ModuleList, in a simpler and easier-to-follow way?

Could someone show with a diagram what this network looks like?


Here is the link to the GitHub repo I took this code from and am trying to understand.

Here is a function equivalent to the main model's forward(x) method. It is much more verbose, but it "unrolls" the flow of operations, which makes it easier to follow.

I assume the list parameters always have length 5 (I use i in the range [0, 4], inclusive) so that I can unpack them correctly (this follows the default parameter set).

The two most important parts are:

  • The skip blocks, where the tensor x is processed in a parallel branch of
    the code instead of interfering with the main x "path".

  • The tensors produced by the skip blocks are then fed back into the "main
    path", starting from the last one. I kept these tensors as individual
    variables s0 to s3 so that this is more visible.

  • From the picture you can clearly see how the lower half feeds the upper
    half: s0 is the longest grey arrow, which joins the "main path" just
    before the last group of convolutional layers (hence the slightly
    different U shape). You can also see from it why s4 does not need to be
    stored: it feeds directly into the next layer, so there is no need to
    keep it in a separate variable.

    The nn.ModuleList version does store it, but only because it is
    convenient to keep everything in one list that is read back in reverse
    order at the end. The other obvious reason for storing the skips in a
    list is that, by changing the parameters accordingly, we can have an
    arbitrary number of up and down sections (see the indexing sketch right
    after this list).
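
To make the reverse read-out concrete, here is a minimal sketch (my addition) of just the list bookkeeping, independent of any tensors: s is filled front to back on the way down and consumed back to front on the way up via s[length - 1 - i].

    length = 5
    s = [f's{i}' for i in range(length)]   # stands in for the skip activations
    print(s[-1])                           # 's4' feeds the first Up directly
    for i in range(1, length):
        print(s[length - 1 - i])           # 's3', 's2', 's1', 's0', in that order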

From the comments:

"Thanks. I don't understand why nu = [128, 128, 128, 128, 128] and nd = [128, 128, 128, 128, 128] are not reduced with depth."

I have the same values in the code I kept above, but I don't understand why it is that way either; it is a good question. I have never used a UNet for inpainting (only for segmentation), so it may well be specific to this task.

"Thanks. In Model_Up, why is in_channels = 132? Is that the output of Model_Down?"

The channels are concatenated: 128 channels come from what I called the "main path" and 4 more come from the "skip". In fact I had made a mistake in Up parts #1 to #3 of the function: I used the wrong channel count relative to your comment (128 channels instead of 132). I have updated the answer to correct it. In "training the model" we can read z = (0.1) * torch.rand((1, 32, 512, 512), device="cuda"); z is fed to the model, and its second dim is the channel one, which explains the axis = 1 argument in the torch.cat calls.
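
To make the channel arithmetic from the comments concrete, here is a minimal sketch (my addition; the 64x64 spatial size is arbitrary) of the concatenation feeding each Up block:

    import torch
    main = torch.zeros(1, 128, 64, 64)       # from the "main path"
    skip = torch.zeros(1, 4, 64, 64)          # from the matching Model_Skip
    merged = torch.cat([main, skip], dim=1)   # dim/axis 1 is the channel dim
    print(merged.shape)                       # torch.Size([1, 132, 64, 64])

The full model code and the unrolled equivalent follow: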
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torchvision
    from PIL import Image
    import matplotlib.pyplot as plt
    
    class Model_Down(nn.Module):
        """
        Convolutional (Downsampling) Blocks.
    
        nd = Number of Filters
        kd = Kernel size
    
        """
        def __init__(self,in_channels, nd = 128, kd = 3, padding = 1, stride = 2):
            super(Model_Down,self).__init__()
            self.padder = nn.ReflectionPad2d(padding)
            self.conv1 = nn.Conv2d(in_channels = in_channels, out_channels = nd, kernel_size = kd, stride = stride)
            self.bn1 = nn.BatchNorm2d(nd)
    
            self.conv2 = nn.Conv2d(in_channels = nd, out_channels = nd, kernel_size = kd, stride = 1)
            self.bn2 = nn.BatchNorm2d(nd)
    
            self.relu = nn.LeakyReLU()
    
        def forward(self, x):
            x = self.padder(x)
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.padder(x)
            x = self.conv2(x)
            x = self.bn2(x)
            x = self.relu(x)
            return x
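
    # Note (my addition): with the defaults kd = 3, padding = 1, stride = 2,
    # conv1 halves the spatial resolution, e.g. (1, 32, 512, 512) -> (1, 128, 256, 256);
    # conv2 then keeps both the resolution and the channel count.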
    
    class Model_Skip(nn.Module):
        """
    
        Skip Connections
    
        ns = Number of filters
        ks = Kernel size
    
        """
        def __init__(self,in_channels = 128, ns = 4, ks = 1, padding = 0, stride = 1):
            super(Model_Skip, self).__init__()
            self.conv = nn.Conv2d(in_channels = in_channels, out_channels = ns, kernel_size = ks, stride = stride, padding = padding)
            self.bn = nn.BatchNorm2d(ns)
            self.relu = nn.LeakyReLU()
    
        def forward(self,x):
            x = self.conv(x)
            x = self.bn(x)
            x = self.relu(x)
            return x
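
    # Note (my addition): the skip branch is a 1x1 convolution by default, so it
    # only remaps channels (128 -> 4) and leaves the spatial resolution untouched.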
    
    
    class Model_Up(nn.Module):
        """
        Convolutional (Upsampling) Blocks.

        nu = Number of Filters
        ku = Kernel size
    
        """
        def __init__(self, in_channels = 132, nu = 128, ku = 3, padding = 1):
            super(Model_Up, self).__init__()
            self.bn1 = nn.BatchNorm2d(in_channels)
            self.padder = nn.ReflectionPad2d(padding)
            self.conv1 = nn.Conv2d(in_channels = in_channels, out_channels = nu, kernel_size = ku, stride = 1, padding = 0)
            self.bn2 = nn.BatchNorm2d(nu)
    
            self.conv2 =  nn.Conv2d(in_channels = nu, out_channels = nu, kernel_size = 1, stride = 1, padding = 0) #According to supmat.pdf ku = 1 for second layer
            self.bn3 = nn.BatchNorm2d(nu)
    
            self.relu = nn.LeakyReLU()
    
        def forward(self,x):
            x = self.bn1(x)
            x = self.padder(x)
            x = self.conv1(x)
            x = self.bn2(x)
            x = self.relu(x)
            x = self.conv2(x)
            x = self.bn3(x)
            x = self.relu(x)
            x = F.interpolate(x, scale_factor = 2, mode = 'bilinear')
            return x
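
    # Note (my addition): bn1 normalizes the block input, which for all but the
    # deepest block is the concatenation of main-path and skip channels (132 by
    # default); the final bilinear interpolation doubles H and W, undoing one
    # stride-2 downsampling step.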
    
    
    class Model(nn.Module):
        def __init__(self, length = 5, in_channels = 32, out_channels = 3, nu = [128,128,128,128,128] , nd =
                        [128,128,128,128,128], ns = [4,4,4,4,4], ku = [3,3,3,3,3], kd = [3,3,3,3,3], ks = [1,1,1,1,1]):
            super(Model,self).__init__()
            assert length == len(nu), 'Hyperparameters do not match network depth.'
    
            self.length = length
    
            self.downs = nn.ModuleList([Model_Down(in_channels = nd[i-1], nd = nd[i], kd = kd[i]) if i != 0 else
                                            Model_Down(in_channels = in_channels, nd = nd[i], kd = kd[i]) for i in range(self.length)])
    
            self.skips = nn.ModuleList([Model_Skip(in_channels = nd[i], ns = ns[i], ks = ks[i]) for i in range(self.length)])
    
            self.ups = nn.ModuleList([Model_Up(in_channels = ns[i]+nu[i+1], nu = nu[i], ku = ku[i]) if i != self.length-1 else
                                            Model_Up(in_channels = ns[i], nu = nu[i], ku = ku[i]) for i in range(self.length-1,-1,-1)]) #Elements ordered backwards
    
            self.conv_out = nn.Conv2d(nu[0],out_channels,1,padding = 0)
            self.sigm = nn.Sigmoid()
    
        def forward(self,x):
            s = [] #Skip Activations
    
            #Downpass
            for i in range(self.length):
                x = self.downs[i].forward(x)
                s.append(self.skips[i].forward(x))
    
            #Uppass
            for i in range(self.length):
                if (i == 0):
                    x = self.ups[i].forward(s[-1])
                else:
                    x = self.ups[i].forward(torch.cat([x,s[self.length-1-i]],axis = 1))
    
            x = self.sigm(self.conv_out(x)) #Squash to RGB ([0,1]) format
            return x
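
    # Quick shape check (my addition, assuming the default hyperparameters):
    # the five stride-2 downs are undone by the five bilinear upsamples, so a
    # (1, 32, 512, 512) input comes back as a (1, 3, 512, 512) image.
    if __name__ == '__main__':
        z = 0.1 * torch.rand((1, 32, 512, 512))
        print(Model()(z).shape)  # torch.Size([1, 3, 512, 512])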
    
    
    def unet_function(x, in_channels = 32, out_channels = 3, nu = [128,128,128,128,128],
                      nd = [128,128,128,128,128], ns = [4,4,4,4,4], ku = [3,3,3,3,3],
                      kd = [3,3,3,3,3], ks = [1,1,1,1,1]):
    
    
        ################################
        # DOWN PASS ####################
        ################################
    
        #########
        # i = 0 #
        #########
    
        # First Down
        # Model_Down(in_channels = in_channels, nd = nd[i], kd = kd[i])
        x = nn.ReflectionPad2d(padding=1)(x)
        x = nn.Conv2d(in_channels=in_channels, out_channels=nd[0], kernel_size=kd[0], stride=2)(x)
        x = nn.BatchNorm2d(nd[0])(x)
        x = nn.LeakyReLU()(x)
        x = nn.ReflectionPad2d(padding=1)(x)
        x = nn.Conv2d(in_channels=nd[0], out_channels=nd[0], kernel_size=kd[0], stride=1)(x)
        x = nn.BatchNorm2d(nd[0])(x)
        x = nn.LeakyReLU()(x)

        # First skip
        # Model_Skip(in_channels = nd[i], ns = ns[i], ks = ks[i])
        s0 = nn.Conv2d(in_channels=nd[0], out_channels=ns[0], kernel_size=ks[0])(x)
        s0 = nn.BatchNorm2d(ns[0])(s0)
        s0 = nn.LeakyReLU()(s0)
    
    
        #########
        # i = 1 #
        #########
    
        # Second Down
        # Model_Down(in_channels = nd[i-1], nd = nd[i], kd = kd[i])
        x = nn.ReflectionPad2d(padding=1)(x)
        x = nn.Conv2d(in_channels=nd[0], out_channels=nd[1], kernel_size=kd[1], stride=2)(x)
        x = nn.BatchNorm2d(nd[1])(x)
        x = nn.LeakyReLU()(x)
        x = nn.ReflectionPad2d(padding=1)(x)
        x = nn.Conv2d(in_channels=nd[1], out_channels=nd[1], kernel_size=kd[1], stride=1)(x)
        x = nn.BatchNorm2d(nd[1])(x)
        x = nn.LeakyReLU()(x)

        # Second skip
        # Model_Skip(in_channels = nd[i], ns = ns[i], ks = ks[i])
        s1 = nn.Conv2d(in_channels=nd[1], out_channels=ns[1], kernel_size=ks[1])(x)
        s1 = nn.BatchNorm2d(ns[1])(s1)
        s1 = nn.LeakyReLU()(s1)
    
    
        #########
        # i = 2 #
        #########
    
        # Third Down
        # Model_Down(in_channels = nd[i-1], nd = nd[i], kd = kd[i])
        x = nn.ReflectionPad2d(padding=1)(x)
        x = nn.Conv2d(in_channels=nd[1], out_channels=nd[2], kernel_size=kd[2], stride=2)(x)
        x = nn.BatchNorm2d(nd[2])(x)
        x = nn.LeakyReLU()(x)
        x = nn.ReflectionPad2d(padding=1)(x)
        x = nn.Conv2d(in_channels=nd[2], out_channels=nd[2], kernel_size=kd[2], stride=1)(x)
        x = nn.BatchNorm2d(nd[2])(x)
        x = nn.LeakyReLU()(x)

        # Third skip
        # Model_Skip(in_channels = nd[i], ns = ns[i], ks = ks[i])
        s2 = nn.Conv2d(in_channels=nd[2], out_channels=ns[2], kernel_size=ks[2])(x)
        s2 = nn.BatchNorm2d(ns[2])(s2)
        s2 = nn.LeakyReLU()(s2)
    
    
        #########
        # i = 3 #
        #########
    
        # Fourth Down
        # Model_Down(in_channels = nd[i-1], nd = nd[i], kd = kd[i])
        x = nn.ReflectionPad2d(padding=1)(x)
        x = nn.Conv2d(in_channels=nd[2], out_channels=nd[3], kernel_size=kd[3], stride=2)(x)
        x = nn.BatchNorm2d(nd[3])(x)
        x = nn.LeakyReLU()(x)
        x = nn.ReflectionPad2d(padding=1)(x)
        x = nn.Conv2d(in_channels=nd[3], out_channels=nd[3], kernel_size=kd[3], stride=1)(x)
        x = nn.BatchNorm2d(nd[3])(x)
        x = nn.LeakyReLU()(x)

        # Fourth skip
        # Model_Skip(in_channels = nd[i], ns = ns[i], ks = ks[i])
        s3 = nn.Conv2d(in_channels=nd[3], out_channels=ns[3], kernel_size=ks[3])(x)
        s3 = nn.BatchNorm2d(ns[3])(s3)
        s3 = nn.LeakyReLU()(s3)
    
    
        #########
        # i = 4 #
        #########
    
        # Fifth Down
        # Model_Down(in_channels = nd[i-1], nd = nd[i], kd = kd[i])
        x = nn.ReflectionPad2d(padding=1)(x)
        x = nn.Conv2d(in_channels=nd[3], out_channels=nd[4], kernel_size=kd[4], stride=2)(x)
        x = nn.BatchNorm2d(nd[4])(x)
        x = nn.LeakyReLU()(x)
        x = nn.ReflectionPad2d(padding=1)(x)
        x = nn.Conv2d(in_channels=nd[4], out_channels=nd[4], kernel_size=kd[4], stride=1)(x)
        x = nn.BatchNorm2d(nd[4])(x)
        x = nn.LeakyReLU()(x)

        # Fifth skip
        # Model_Skip(in_channels = nd[i], ns = ns[i], ks = ks[i])
        x = nn.Conv2d(in_channels=nd[4], out_channels=ns[4], kernel_size=ks[4])(x)
        x = nn.BatchNorm2d(ns[4])(x)
        x = nn.LeakyReLU()(x)
    
    
    
        ################################
        # UP PASS ######################
        ################################
    
        #########
        # i = 4 #
        #########
    
        # First Up
        # Model_Up(in_channels = ns[i], nu = nu[i], ku = ku[i])
        x = nn.BatchNorm2d(ns[4])(x)
        x = nn.ReflectionPad2d(padding=1)(x)
        x = nn.Conv2d(in_channels=ns[4], out_channels=nu[4], kernel_size=ku[4], stride=1, padding=0)(x)
        x = nn.BatchNorm2d(nu[4])(x)
        x = nn.LeakyReLU()(x)
        x = nn.Conv2d(in_channels = nu[4], out_channels=nu[4], kernel_size = 1, stride = 1, padding = 0)(x)
        x = nn.BatchNorm2d(nu[4])(x)
        x = nn.LeakyReLU()(x)
        x = F.interpolate(x, scale_factor = 2, mode = 'bilinear')
    
    
        #########
        # i = 3 #
        #########
    
        # Second Up
        # self.ups[i].forward(torch.cat([x,s[self.length-1-i]],axis = 1))
        x = torch.cat([x,s3], axis=1) # IMPORTANT HERE
        # Model_Up(in_channels = ns[i]+nu[i+1], nu = nu[i], ku = ku[i])
        x = nn.BatchNorm2d(ns[3]+nu[4])(x)
        x = nn.ReflectionPad2d(padding=1)(x)
        x = nn.Conv2d(in_channels=ns[3]+nu[4], out_channels=nu[3], kernel_size=ku[3], stride=1, padding=0)(x)
        x = nn.BatchNorm2d(nu[3])(x)
        x = nn.LeakyReLU()(x)
        x = nn.Conv2d(in_channels=nu[3], out_channels=nu[3], kernel_size=1, stride=1, padding=0)(x)
        x = nn.BatchNorm2d(nu[3])(x)
        x = nn.LeakyReLU()(x)
        x = F.interpolate(x, scale_factor = 2, mode = 'bilinear')
    
    
        #########
        # i = 2 #
        #########
    
        # Third Up
        # self.ups[i].forward(torch.cat([x,s[self.length-1-i]],axis = 1))
        x = torch.cat([x,s2], axis=1) # IMPORTANT HERE
        # Model_Up(in_channels = ns[i]+nu[i+1], nu = nu[i], ku = ku[i])
        x = nn.BatchNorm2d(ns[2]+nu[3])(x)
        x = nn.ReflectionPad2d(padding=1)(x)
        x = nn.Conv2d(in_channels=ns[2]+nu[3], out_channels=nu[2], kernel_size=ku[2], stride=1, padding=0)(x)
        x = nn.BatchNorm2d(nu[2])(x)
        x = nn.LeakyReLU()(x)
        x = nn.Conv2d(in_channels=nu[2], out_channels=nu[2], kernel_size=1, stride=1, padding=0)(x)
        x = nn.BatchNorm2d(nu[2])(x)
        x = nn.LeakyReLU()(x)
        x = F.interpolate(x, scale_factor = 2, mode = 'bilinear')
    
    
        #########
        # i = 1 #
        #########
    
        # Fourth Up
        # self.ups[i].forward(torch.cat([x,s[self.length-1-i]],axis = 1))
        x = torch.cat([x,s1], axis=1) # IMPORTANT HERE
        # Model_Up(in_channels = ns[i]+nu[i+1], nu = nu[i], ku = ku[i])
        x = nn.BatchNorm2d(ns[1]+nu[2])(x)
        x = nn.ReflectionPad2d(padding=1)(x)
        x = nn.Conv2d(in_channels=ns[1]+nu[2], out_channels=nu[1], kernel_size=ku[1], stride=1, padding=0)(x)
        x = nn.BatchNorm2d(nu[1])(x)
        x = nn.LeakyReLU()(x)
        x = nn.Conv2d(in_channels=nu[1], out_channels=nu[1], kernel_size=1, stride=1, padding=0)(x)
        x = nn.BatchNorm2d(nu[1])(x)
        x = nn.LeakyReLU()(x)
        x = F.interpolate(x, scale_factor = 2, mode = 'bilinear')    
    
    
        #########
        # i = 0 #
        #########
    
        # Fifth Up
        # self.ups[i].forward(torch.cat([x,s[self.length-1-i]],axis = 1))
        x = torch.cat([x,s0], axis=1) # IMPORTANT HERE
        # Model_Up(in_channels = ns[i]+nu[i+1], nu = nu[i], ku = ku[i])
        x = nn.BatchNorm2d(ns[0]+nu[1])(x)
        x = nn.ReflectionPad2d(padding=1)(x)
        x = nn.Conv2d(in_channels=ns[0]+nu[1], out_channels=nu[0], kernel_size=ku[0], stride=1, padding=0)(x)
        x = nn.BatchNorm2d(nu[0])(x)
        x = nn.LeakyReLU()(x)
        x = nn.Conv2d(in_channels = nu[0], out_channels=nu[0], kernel_size = 1, stride = 1, padding = 0)(x)
        x = nn.BatchNorm2d(nu[0])(x)
        x = nn.LeakyReLU()(x)
        x = F.interpolate(x, scale_factor = 2, mode = 'bilinear')
    
    
        ################################
        # OUT ##########################
        ################################
    
        x = nn.Conv2d(in_channels=nu[0], out_channels=out_channels, kernel_size=1, padding=0)(x)
        return nn.Sigmoid()(x) # Squash to RGB ([0,1]) format
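
    # The unrolled version can be smoke-tested the same way (my addition). Keep
    # in mind that every nn.* layer in unet_function is instantiated on the fly
    # with fresh random weights, so it is only meant for following the flow of
    # shapes, not for training.
    if __name__ == '__main__':
        z = 0.1 * torch.rand((1, 32, 512, 512))
        print(unet_function(z).shape)  # torch.Size([1, 3, 512, 512])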