Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/python/310.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python Pyrotch获取模型的所有层_Python_Pytorch - Fatal编程技术网

Python Pyrotch获取模型的所有层

Python Pyrotch获取模型的所有层,python,pytorch,Python,Pytorch,采用pytorch模型并获得所有层的列表(无任何nn.Sequence分组)的最简单方法是什么?例如,有更好的方法吗 import pretrainedmodels def unwrap_model(model): for i in children(model): if isinstance(i, nn.Sequential): unwrap_model(i) else: l.append(i) model = pretrainedmodels.__

采用pytorch模型并获得所有层的列表(无任何
nn.Sequence
分组)的最简单方法是什么?例如,有更好的方法吗

import pretrainedmodels

def unwrap_model(model):
    for i in children(model):
        if isinstance(i, nn.Sequential): unwrap_model(i)
        else: l.append(i)

model = pretrainedmodels.__dict__['xception'](num_classes=1000, pretrained='imagenet')
l = []
unwrap_model(model)            
            
print(l)
    

您可以使用该方法迭代模型的所有模块(包括每个
Sequential
)中的模块。下面是一个简单的例子:

>>> model = nn.Sequential(nn.Linear(2, 2), 
                          nn.ReLU(),
                          nn.Sequential(nn.Linear(2, 1),
                          nn.Sigmoid()))

>>> l = [module for module in model.modules() if not isinstance(module, nn.Sequential)]

>>> l

[Linear(in_features=2, out_features=2, bias=True),
 ReLU(),
 Linear(in_features=2, out_features=1, bias=True),
 Sigmoid()]
我是这样做的:

def flatten(el):
    flattened = [flatten(children) for children in el.children()]
    res = [el]
    for c in flattened:
        res += c
    return res

cnn = nn.Sequential(Custom_block_1, Custom_block_2)
layers = flatten(cnn)

我为一个更深层的模型进行了网络划分,并不是所有的块都来自nn.sequential

def get_children(model: torch.nn.Module):
    # get children form model!
    children = list(model.children())
    flatt_children = []
    if children == []:
        # if model has no children; model is last child! :O
        return model
    else:
       # look for children from children... to the last child!
       for child in children:
            try:
                flatt_children.extend(get_children(child))
            except TypeError:
                flatt_children.append(get_children(child))
    return flatt_children

如果您希望以命名的
dict
格式显示图层,这是最简单的方法:

named_layers = dict(model.named_modules())
这将返回类似于:

{
    'conv1': <some conv layer>,
    'fc1': < some fc layer>,
     ### and other layers 
}

我有一个由几个模块组成的ResNet。这个答案很有效,它是通用的,这是好的。