Python 如何pickle使用lambda函数的任意pytorch模型?
我目前有一个神经网络模块:Python 如何pickle使用lambda函数的任意pytorch模型?,python,pytorch,pickle,Python,Pytorch,Pickle,我目前有一个神经网络模块: import torch.nn as nn class NN(nn.Module): def __init__(self,args,lambda_f,nn1, loss, opt): super().__init__() self.args = args self.lambda_f = lambda_f self.nn1 = nn1 self.loss = loss
import torch.nn as nn
class NN(nn.Module):
def __init__(self,args,lambda_f,nn1, loss, opt):
super().__init__()
self.args = args
self.lambda_f = lambda_f
self.nn1 = nn1
self.loss = loss
self.opt = opt
# more nn.Params stuff etc...
def forward(self, x):
#some code using fields
return out
我正在尝试检查它,但因为Pytork使用state_dict
s保存,这意味着如果我使用Pytorktorch.save进行检查,我无法保存我实际使用的lambda函数。我真的想保存所有内容而不出问题,然后重新加载到GPU上进行训练。我目前正在使用:
def save_ckpt(path_to_ckpt):
from pathlib import Path
import dill as pickle
## Make dir. Throw no exceptions if it already exists
path_to_ckpt.mkdir(parents=True, exist_ok=True)
ckpt_path_plus_path = path_to_ckpt / Path('db')
## Pickle args
db['crazy_mdl'] = crazy_mdl
with open(ckpt_path_plus_path , 'ab') as db_file:
pickle.dump(db, db_file)
当前,当我选中它并保存它时,它不会抛出任何错误
我担心在培训it时,即使没有培训异常/错误,也可能会出现一个微妙的错误,或者可能会发生意外情况(例如,在集群中的磁盘上进行奇怪的保存等,谁知道呢)
这对pytorch类/nn模型安全吗?特别是如果我们想用GPU恢复训练
交叉张贴:
- ?
我是迪尔的作者。我使用
dill
(和klepot
)保存lambda函数中包含经过训练的ann的类。我倾向于使用mystic
和sklearn
的组合,因此我无法直接与pytorch
对话,但我可以假设它的工作原理相同。需要注意的地方是,如果lambda包含指向lambda外部对象的指针。。。例如y=4;f=λx:x+y
。这可能看起来很明显,但是dill将pickle lambda,并且根据代码的其余部分和序列化变量,可能不会序列化y
的值。因此,我见过很多情况,人们在函数(或lambda或类)中序列化一个经过训练的估计器,然后从序列化中恢复函数时,结果不“正确”。最主要的原因是函数未封装,因此函数生成正确结果所需的所有对象都存储在pickle中。然而,即使在这种情况下,您也可以得到“正确”的结果,但是您只需要创建与处理估计器时相同的环境(即,它在周围名称空间中依赖的所有相同值)。要点应该是,尝试确保函数中使用的所有变量都在函数中定义。这是我最近开始使用的一个类的一部分(应该在下一个版本的mystic
中):
注意:调用函数时,它使用的所有内容(包括np
)都在周围的名称空间中定义。只要pytorch估计器按预期进行序列化(无外部引用),那么遵循上述指导原则就可以了。是,我认为使用dill
来pickle lambda函数等是安全的。我一直在使用torch.save
和dill来保存state dict,并且在通过GPU和CPU恢复训练时没有问题,除非模型类被更改。即使模型类被更改(添加/删除一些参数),我也可以加载state dict,修改它,然后加载到模型中
此外,通常情况下,人们不保存模型对象,而只保存状态dict,即参数值,以恢复训练,以及超参数/模型参数,以便以后获得相同的模型对象
保存模型对象有时会有问题,因为对模型类(代码)的更改会使保存的对象变得无用。如果您根本不打算更改模型类/代码,因此不会更改模型对象,那么保存对象可能会很好,但通常不建议对模块对象进行pickle 好的!所以我猜dill的工作原理和我预期的有点不同。我假设用于lambda函数的dill会将“闭包”(即函数名、主体和程序环境/命名空间)保存为pickle执行期间的状态。您似乎在说,这不是它的工作方式,相反,它所做的是保存名称,然后使用当前本地环境解析lambda函数。对吗?这让我很难过,因为这意味着我不能不用担心就用莳萝来泡菜(虽然我不想听起来忘恩负义,但我确信这是一个很难解决的问题)。你展示的代码片段是不是在试图教我/演示如何做我需要的事情?i、 e.通过在我们pickle的类的字段定义之后保存lambda函数来获得正确的程序env?这样,当它被还原时,它使用它在还原lambda函数时所pickle的数据值。这基本上就是你想告诉我的吗?dill
有几个序列化变体,它们都以不同的方式对待全局名称空间。不保存任何全局名称空间(如pickle
),尝试保存直接引用的名称空间的所有成员(如cloudpickle
),然后保存dill
独有的两个变体——将全局名称空间保存为dict,并通过提取生成的代码保存对象。因此,dill
确实保存了名称空间。我想说的是,你必须帮助它,以确保参考连接到预期的。。。您可以通过如上所述封装变量来实现这一点。我在上面的代码中展示的是一种策略,它允许您编写一个lambda,该lambda使用指向封闭名称空间的指针引用,并且仍将按照预期进行序列化。裸lambda将序列化,但更可能出现的问题是,lambda中的引用未指向预期值。
class Estimator(object):
"a container for a trained estimator and transform (not a pipeline)"
def __init__(self, estimator, transform):
"""a container for a trained estimator and transform
Input:
estimator: a fitted sklearn estimator
transform: a fitted sklearn transform
"""
self.estimator = estimator
self.transform = transform
self.function = lambda *x: float(self.estimator.predict(self.transform.transform(np.array(x).reshape(1,-1))).reshape(-1))
def __call__(self, *x):
"f(*x) for x of xtest and predict on fitted estimator(transform(xtest))"
import numpy as np
return self.function(*x)