Python Pyrotch获取模型的所有层
采用pytorch模型并获得所有层的列表(无任何Python Pyrotch获取模型的所有层,python,pytorch,Python,Pytorch,采用pytorch模型并获得所有层的列表(无任何nn.Sequence分组)的最简单方法是什么?例如,有更好的方法吗 import pretrainedmodels def unwrap_model(model): for i in children(model): if isinstance(i, nn.Sequential): unwrap_model(i) else: l.append(i) model = pretrainedmodels.__
nn.Sequence
分组)的最简单方法是什么?例如,有更好的方法吗
import pretrainedmodels
def unwrap_model(model):
for i in children(model):
if isinstance(i, nn.Sequential): unwrap_model(i)
else: l.append(i)
model = pretrainedmodels.__dict__['xception'](num_classes=1000, pretrained='imagenet')
l = []
unwrap_model(model)
print(l)
您可以使用该方法迭代模型的所有模块(包括每个
Sequential
)中的模块。下面是一个简单的例子:
>>> model = nn.Sequential(nn.Linear(2, 2),
nn.ReLU(),
nn.Sequential(nn.Linear(2, 1),
nn.Sigmoid()))
>>> l = [module for module in model.modules() if not isinstance(module, nn.Sequential)]
>>> l
[Linear(in_features=2, out_features=2, bias=True),
ReLU(),
Linear(in_features=2, out_features=1, bias=True),
Sigmoid()]
我是这样做的:
def flatten(el):
flattened = [flatten(children) for children in el.children()]
res = [el]
for c in flattened:
res += c
return res
cnn = nn.Sequential(Custom_block_1, Custom_block_2)
layers = flatten(cnn)
我为一个更深层的模型进行了网络划分,并不是所有的块都来自nn.sequential
def get_children(model: torch.nn.Module):
# get children form model!
children = list(model.children())
flatt_children = []
if children == []:
# if model has no children; model is last child! :O
return model
else:
# look for children from children... to the last child!
for child in children:
try:
flatt_children.extend(get_children(child))
except TypeError:
flatt_children.append(get_children(child))
return flatt_children
如果您希望以命名的
dict
格式显示图层,这是最简单的方法:
named_layers = dict(model.named_modules())
这将返回类似于:
{
'conv1': <some conv layer>,
'fc1': < some fc layer>,
### and other layers
}
我有一个由几个模块组成的ResNet。这个答案很有效,它是通用的,这是好的。