Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/python/362.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python 如何pickle使用lambda函数的任意pytorch模型?_Python_Pytorch_Pickle - Fatal编程技术网

Python 如何pickle使用lambda函数的任意pytorch模型?

Python 如何pickle使用lambda函数的任意pytorch模型?,python,pytorch,pickle,Python,Pytorch,Pickle,我目前有一个神经网络模块: import torch.nn as nn class NN(nn.Module): def __init__(self,args,lambda_f,nn1, loss, opt): super().__init__() self.args = args self.lambda_f = lambda_f self.nn1 = nn1 self.loss = loss

我目前有一个神经网络模块:

import torch.nn as nn

class NN(nn.Module):
    def __init__(self,args,lambda_f,nn1, loss, opt):
        super().__init__()
        self.args = args
        self.lambda_f = lambda_f
        self.nn1 = nn1
        self.loss = loss
        self.opt = opt
        # more nn.Params stuff etc...

    def forward(self, x):
        #some code using fields
        return out
我正在尝试检查它,但因为Pytork使用
state_dict
s保存,这意味着如果我使用Pytork
torch.save进行检查,我无法保存我实际使用的lambda函数。我真的想保存所有内容而不出问题,然后重新加载到GPU上进行训练。我目前正在使用:

def save_ckpt(path_to_ckpt):
    from pathlib import Path
    import dill as pickle
    ## Make dir. Throw no exceptions if it already exists
    path_to_ckpt.mkdir(parents=True, exist_ok=True)
    ckpt_path_plus_path = path_to_ckpt / Path('db')

    ## Pickle args
    db['crazy_mdl'] = crazy_mdl
    with open(ckpt_path_plus_path , 'ab') as db_file:
        pickle.dump(db, db_file)
当前,当我选中它并保存它时,它不会抛出任何错误

我担心在培训it时,即使没有培训异常/错误,也可能会出现一个微妙的错误,或者可能会发生意外情况(例如,在集群中的磁盘上进行奇怪的保存等,谁知道呢)

这对pytorch类/nn模型安全吗?特别是如果我们想用GPU恢复训练

交叉张贴:

  • ?

    • 我是迪尔的作者。我使用
      dill
      (和
      klepot
      )保存lambda函数中包含经过训练的ann的类。我倾向于使用
      mystic
      sklearn
      的组合,因此我无法直接与
      pytorch
      对话,但我可以假设它的工作原理相同。需要注意的地方是,如果lambda包含指向lambda外部对象的指针。。。例如
      y=4;f=λx:x+y
      。这可能看起来很明显,但是dill将pickle lambda,并且根据代码的其余部分和序列化变量,可能不会序列化
      y
      的值。因此,我见过很多情况,人们在函数(或lambda或类)中序列化一个经过训练的估计器,然后从序列化中恢复函数时,结果不“正确”。最主要的原因是函数未封装,因此函数生成正确结果所需的所有对象都存储在pickle中。然而,即使在这种情况下,您也可以得到“正确”的结果,但是您只需要创建与处理估计器时相同的环境(即,它在周围名称空间中依赖的所有相同值)。要点应该是,尝试确保函数中使用的所有变量都在函数中定义。这是我最近开始使用的一个类的一部分(应该在下一个版本的
      mystic
      中):


      注意:调用函数时,它使用的所有内容(包括
      np
      )都在周围的名称空间中定义。只要pytorch估计器按预期进行序列化(无外部引用),那么遵循上述指导原则就可以了。

      是,我认为使用
      dill
      来pickle lambda函数等是安全的。我一直在使用
      torch.save
      和dill来保存state dict,并且在通过GPU和CPU恢复训练时没有问题,除非模型类被更改。即使模型类被更改(添加/删除一些参数),我也可以加载state dict,修改它,然后加载到模型中

      此外,通常情况下,人们不保存模型对象,而只保存状态dict,即参数值,以恢复训练,以及超参数/模型参数,以便以后获得相同的模型对象


      保存模型对象有时会有问题,因为对模型类(代码)的更改会使保存的对象变得无用。如果您根本不打算更改模型类/代码,因此不会更改模型对象,那么保存对象可能会很好,但通常不建议对模块对象进行pickle

      好的!所以我猜dill的工作原理和我预期的有点不同。我假设用于lambda函数的dill会将“闭包”(即函数名、主体和程序环境/命名空间)保存为pickle执行期间的状态。您似乎在说,这不是它的工作方式,相反,它所做的是保存名称,然后使用当前本地环境解析lambda函数。对吗?这让我很难过,因为这意味着我不能不用担心就用莳萝来泡菜(虽然我不想听起来忘恩负义,但我确信这是一个很难解决的问题)。你展示的代码片段是不是在试图教我/演示如何做我需要的事情?i、 e.通过在我们pickle的类的字段定义之后保存lambda函数来获得正确的程序env?这样,当它被还原时,它使用它在还原lambda函数时所pickle的数据值。这基本上就是你想告诉我的吗?
      dill
      有几个序列化变体,它们都以不同的方式对待全局名称空间。不保存任何全局名称空间(如
      pickle
      ),尝试保存直接引用的名称空间的所有成员(如
      cloudpickle
      ),然后保存
      dill
      独有的两个变体——将全局名称空间保存为dict,并通过提取生成的代码保存对象。因此,
      dill
      确实保存了名称空间。我想说的是,你必须帮助它,以确保参考连接到预期的。。。您可以通过如上所述封装变量来实现这一点。我在上面的代码中展示的是一种策略,它允许您编写一个lambda,该lambda使用指向封闭名称空间的指针引用,并且仍将按照预期进行序列化。裸lambda将序列化,但更可能出现的问题是,lambda中的引用未指向预期值。
      class Estimator(object):
          "a container for a trained estimator and transform (not a pipeline)"
          def __init__(self, estimator, transform):
              """a container for a trained estimator and transform
      
          Input:
              estimator: a fitted sklearn estimator
              transform: a fitted sklearn transform
              """
              self.estimator = estimator
              self.transform = transform
              self.function = lambda *x: float(self.estimator.predict(self.transform.transform(np.array(x).reshape(1,-1))).reshape(-1))
          def __call__(self, *x):
              "f(*x) for x of xtest and predict on fitted estimator(transform(xtest))"
              import numpy as np
              return self.function(*x)