Python: how to calculate FLOPs and parameters without counting zero weights?


My pruning code is shown below. After running it, I get a file named "pruned_model.pth".

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
from cnn import net

ori_model = '/content/drive/My Drive/ECG_weight_prune/checkpoint_dir/model.pth'
save_path = '/content/drive/My Drive/ECG_weight_prune/checkpoint_dir/pruned_model.pth'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = net().to(device)
model.load_state_dict(torch.load(ori_model))  

module = model.conv1
print(list(module.named_parameters()))
print(list(module.named_buffers()))

# L1-unstructured pruning: zero out 30% of conv1's weights and 3 of its bias entries
prune.l1_unstructured(module, name="weight", amount=0.3)
prune.l1_unstructured(module, name="bias", amount=3)
print(list(module.named_parameters()))
print(list(module.named_buffers()))
print(module.bias)
print(module.weight)
print(module._forward_pre_hooks)
# Make the pruning permanent: fold the masks into the parameters and drop the re-parametrization
prune.remove(module, 'weight')
prune.remove(module, 'bias')
print(list(module.named_parameters()))
print(model.state_dict())
torch.save(model.state_dict(), save_path)
The result is:

[('weight', Parameter containing:
tensor([[[-0.0000, -0.3137, -0.3221,  ...,  0.5055,  0.3614, -0.0000]],

        [[ 0.8889,  0.2697, -0.3400,  ...,  0.8546,  0.2311, -0.0000]],

        [[-0.2649, -0.1566, -0.0000,  ...,  0.0000,  0.0000,  0.3855]],

        ...,

        [[-0.2836, -0.0000,  0.2155,  ..., -0.8894, -0.7676, -0.6271]],

        [[-0.7908, -0.6732, -0.5024,  ...,  0.2011,  0.4627,  1.0227]],

        [[ 0.4433,  0.5048,  0.7685,  ..., -1.0530, -0.8908, -0.4799]]],
       device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([-0.7497, -1.3594, -1.7613, -2.0137, -1.1763,  0.4150, -1.6996, -1.5354,
         0.4330, -0.9259,  0.4156, -2.3099, -0.4282, -0.5199,  0.1188, -1.1725,
        -0.9064, -1.6639, -1.5834, -0.3655, -2.0727, -2.1078, -1.6431, -0.0694,
        -0.5435, -1.9623,  0.5481, -0.8255, -1.5108, -0.4029, -1.9759,  0.0522,
         0.0599, -2.2469, -0.5599,  0.1039, -0.4472, -1.1706, -0.0398, -1.9441,
        -1.5310, -0.0837, -1.3250, -0.2098, -0.1919,  0.4600, -0.8268, -1.0041,
        -0.8168, -0.8701,  0.3869,  0.1706, -0.0226, -1.2711, -0.9302, -2.0696,
        -1.1838,  0.4497, -1.1426,  0.0772, -2.4356, -0.3138,  0.6297,  0.2022,
        -0.4024,  0.0000, -1.2337,  0.2840,  0.4515,  0.2999,  0.0273,  0.0374,
         0.1325, -0.4890, -2.3845, -1.9663,  0.2108, -0.1144,  0.0544, -0.2629,
         0.0393, -0.6728, -0.9645,  0.3118, -0.5142, -0.4097, -0.0000, -1.5142,
        -1.2798,  0.2871, -2.0122, -0.9346, -0.4931, -1.4895, -1.1401, -0.8823,
         0.2210,  0.4282,  0.1685, -1.8876, -0.7459,  0.2505, -0.6315,  0.3827,
        -0.3348,  0.1862,  0.0806, -2.0277,  0.2068,  0.3281, -1.8045, -0.0000,
        -2.2377, -1.9742, -0.5164, -0.0660,  0.8392,  0.5863, -0.7301,  0.0778,
         0.1611,  0.0260,  0.3183, -0.9097, -1.6152,  0.4712, -0.2378, -0.4972],
       device='cuda:0', requires_grad=True))]
There are many zero weights. How can I compute the FLOPs and the number of parameters without counting the operations associated with these zero values?

I use the code below to compute the FLOPs and params:

import torch
from cnn import net
from ptflops import get_model_complexity_info

ori_model = '/content/drive/My Drive/ECG_weight_prune/checkpoint_dir/model.pth'
pthfile = '/content/drive/My Drive/ECG_weight_prune/checkpoint_dir/pruned_model.pth'

model = net()
# model.load_state_dict(torch.load(ori_model))  
model.load_state_dict(torch.load(pthfile))  
# print(model.state_dict())

macs, params = get_model_complexity_info(model, (1, 260), as_strings=False,
                                         print_per_layer_stat=True, verbose=True)
print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
print('{:<30}  {:<8}'.format('Number of parameters: ', params))

One thing you can do is exclude the weights that are below a certain threshold from the FLOPs computation. To do that, you have to modify the FLOPs counter functions.

Below I provide examples of the modified hooks for the fc and conv layers.

# Modified versions of ptflops' counter hooks; they rely on numpy and torch.
import numpy as np
import torch

def linear_flops_counter_hook(module, input, output):
    input = input[0]
    output_last_dim = output.shape[-1]  # pytorch checks dimensions, so here we don't care much
    # MODIFICATION HAPPENS HERE
    num_zero_weights = (module.weight.data.abs() < 1e-9).sum()
    zero_weights_factor = 1 - torch.true_divide(num_zero_weights, module.weight.data.numel())
    module.__flops__ += int(np.prod(input.shape) * output_last_dim) * zero_weights_factor.numpy()
    # MODIFICATION HAPPENS HERE

def conv_flops_counter_hook(conv_module, input, output):
    # Can have multiple inputs, getting the first one
    input = input[0]

    batch_size = input.shape[0]
    output_dims = list(output.shape[2:])

    kernel_dims = list(conv_module.kernel_size)
    in_channels = conv_module.in_channels
    out_channels = conv_module.out_channels
    groups = conv_module.groups

    filters_per_channel = out_channels // groups
    conv_per_position_flops = int(np.prod(kernel_dims)) * in_channels * filters_per_channel

    active_elements_count = batch_size * int(np.prod(output_dims))

    # MODIFICATION HAPPENS HERE
    num_zero_weights = (conv_module.weight.data.abs() < 1e-9).sum()
    zero_weights_factor = 1 - torch.true_divide(num_zero_weights, conv_module.weight.data.numel())
    overall_conv_flops = conv_per_position_flops * active_elements_count * zero_weights_factor.numpy()
    # MODIFICATION HAPPENS HERE

    bias_flops = 0

    if conv_module.bias is not None:

        bias_flops = out_channels * active_elements_count

    overall_flops = overall_conv_flops + bias_flops

    conv_module.__flops__ += int(overall_flops)
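
To use these modified hooks, you can either patch them directly into your local ptflops installation (replacing its stock linear and conv counter hooks) or, if your installed ptflops version exposes the custom_modules_hooks argument, pass them in from the measurement script. The following is only a minimal sketch under that assumption (and it also assumes the conv layers are nn.Conv1d, as the (1, 260) input shape suggests):

import torch.nn as nn
from ptflops import get_model_complexity_info

# Assumption: this ptflops version accepts custom_modules_hooks; if yours does not,
# edit the hook functions in ptflops' flops counter source instead.
macs, params = get_model_complexity_info(
    model, (1, 260), as_strings=False,
    print_per_layer_stat=True, verbose=True,
    custom_modules_hooks={
        nn.Linear: linear_flops_counter_hook,  # modified hooks defined above
        nn.Conv1d: conv_flops_counter_hook,    # assumed layer type for the (1, 260) ECG input
    })
print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
print('{:<30}  {:<8}'.format('Number of parameters: ', params))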

Note that I use 1e-9 as the threshold below which a weight is counted as zero.
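
The hooks above only adjust the FLOPs count; ptflops will still report the full, dense parameter count. If you also want a parameter figure that ignores the pruned entries, a minimal sketch using the same 1e-9 threshold could look like this (the helper count_nonzero_params is just for illustration, it is not part of ptflops):

import torch

def count_nonzero_params(model, threshold=1e-9):
    # Count parameters whose magnitude exceeds the threshold,
    # i.e. everything that survived pruning.
    total = sum(p.numel() for p in model.parameters())
    nonzero = sum((p.detach().abs() > threshold).sum().item() for p in model.parameters())
    return total, nonzero

total, nonzero = count_nonzero_params(model)
print('Parameters: {} non-zero out of {} total'.format(nonzero, total))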

Do I understand correctly that you want to estimate the computational complexity as if the operations on the zero weights were not executed?
Yes, because after pruning there will be many zeros that should not count toward the FLOPs and the number of parameters. This way I can see the improvement gained from pruning.