Python 是否可以从GCS bucket URL加载预训练的Pytorch模型,而不首先在本地持久化?

Python 是否可以从GCS bucket URL加载预训练的Pytorch模型,而不首先在本地持久化?,python,google-cloud-storage,pytorch,google-cloud-dataflow,Python,Google Cloud Storage,Pytorch,Google Cloud Dataflow,我是在谷歌数据流的背景下问这个问题的,但也是一般性的 使用PyTorch,我可以引用包含多个文件的本地目录,这些文件构成了一个预训练模型。我碰巧在使用一个Roberta模型,但其他模型的界面是相同的 ls some-directory/ added_tokens.json config.json merges.txt pytorch_model.bin special_

我是在谷歌数据流的背景下问这个问题的,但也是一般性的

使用PyTorch,我可以引用包含多个文件的本地目录,这些文件构成了一个预训练模型。我碰巧在使用一个Roberta模型,但其他模型的界面是相同的

ls some-directory/
      added_tokens.json
      config.json             
      merges.txt              
      pytorch_model.bin       
      special_tokens_map.json vocab.json
从pytorch_变压器导入Roberta模型
#这很有效
model=RobertaModel.from_pretrained('/path/to/some directory/'))
但是,我的预训练模型存储在GCS桶中。让我们称之为
gs://my bucket/roberta/

在GoogleDataflow中加载此模型的上下文中,我试图保持无状态并避免持久化到磁盘,因此我的首选是直接从GCS获取此模型。据我所知,来自_pretrained()的PyTorch通用接口方法
可以采用本地目录或URL的字符串表示。但是,我似乎无法从GCS URL加载模型

#此操作失败
model=RobertaModel.from_pretrained('gs://my bucket/roberta/'))
#ValueError:无法将gs://mahmed\u bucket/roberta base解析为URL或本地路径
如果我尝试使用目录blob的公共https URL,它也会失败,尽管这可能是因为python环境中引用的可以创建客户端的凭据不会转换为对
https://storage.googleapis

# this fails, probably due to auth
bucket = gcs_client.get_bucket('my-bucket')
directory_blob = bucket.blob(prefix='roberta')
model = RobertaModel.from_pretrained(directory_blob.public_url)
# ValueError: No JSON object could be decoded

# and for good measure, it also fails if I append a trailing /
model = RobertaModel.from_pretrained(directory_blob.public_url + '/')
# ValueError: No JSON object could be decoded
我理解这一点,它实际上只是一个位于bucket名称下的平面名称空间。然而,我似乎被身份验证的必要性和一个不会说话的PyTorch阻塞了

我可以先在本地持久化文件来解决这个问题

从pytorch_变压器导入Roberta模型
从google.cloud导入存储
导入临时文件
local_dir=tempfile.mkdtemp()
gcs=storage.Client()
bucket=gcs.get\u bucket(bucket\u名称)
blob=bucket.list\u blob(前缀=blob\u前缀)
对于blob中的blob:
blob.download_to_filename(local_dir+'/'+os.path.basename(blob.name))
model=RobertaModel.from_pretrained(本地_dir)
但这看起来像是一个黑客,我一直在想我一定错过了什么。当然有一种方法可以保持无状态,而不必依赖于磁盘持久性

  • 那么,有没有办法加载存储在GCS中的预训练模型
  • 在这种情况下执行公共URL请求时,是否有方法进行身份验证
  • 即使有办法进行身份验证,子目录的不存在是否仍然是一个问题
谢谢你的帮助!我也很高兴被指出任何重复的问题,因为我肯定找不到任何问题


编辑和澄清

  • 我的Python会话已经通过GCS的身份验证,这就是为什么我能够在本地下载blob文件,然后使用
    load\u frompretrained()

  • load\u frompretrained()
    需要目录引用,因为它需要问题顶部列出的所有文件,而不仅仅是
    pytorch model.bin

  • 为了澄清问题#2,我想知道是否有某种方法可以为PyTorch方法提供一个嵌入了加密凭据的请求URL或类似的东西。有点长,但我想确保我没有错过任何东西

  • 为了澄清问题#3(除了下面对一个答案的评论),即使有一种方法可以在URL中嵌入我不知道的凭据,我仍然需要引用一个目录而不是一个blob,我不知道GCS子目录是否会被识别为这样,因为(正如Google docs声明的那样)GCS中的子目录是一种幻觉,它们并不代表真正的目录结构。所以我认为这个问题是不相关的,或者至少被问题2挡住了,但这是我追逐的线索,所以我仍然很好奇


正如您正确指出的那样,开箱即用的
pytorch transformers
似乎不支持这一点,但主要是因为它无法将文件链接识别为URL

经过一些搜索,我在第144-155行附近的中找到了相应的错误消息

当然,您可以尝试将
'gs'
标记添加到第144行,然后将您与地面军事系统的连接解释为
http
请求(第269-272行)。如果地面军事系统接受这一点,这应该是唯一需要改变的事情,以便工作。

如果这不起作用,唯一直接的解决办法是实现类似于Amazon S3 bucket函数的功能,但我对S3和GCS bucket的了解还不够,无法在这里做出任何有意义的判断。

我对Pytorch或Roberta模型知之甚少,但我将尝试回答您关于GCS的询问:

1.-“有没有办法加载存储在地面军事系统中的预训练模型?”

如果您的模型可以直接从二进制文件加载Blob:

from google.cloud import storage

client = storage.Client()
bucket = client.get_bucket("bucket name")
blob = bucket.blob("path_to_blob/blob_name.ext")
data = blob.download_as_string() # you will have your binary data transformed into string here.
2.-“在此上下文中执行公共URL请求时,是否有方法进行身份验证?”

这里是棘手的部分,因为根据您运行脚本的上下文,它将使用默认服务帐户进行身份验证。因此,当您使用官方GCP LIB时,您可以:

A.-授予该默认服务帐户访问您的bucket/对象的权限

B.-创建一个新的服务帐户并在脚本中使用它进行身份验证(您还需要为该服务帐户生成身份验证令牌):

但是,这是可行的,因为官方的libs在后台处理对API调用的身份验证,因此在from_pretrained()函数的情况下不起作用

因此,另一种方法是将对象公开,这样您可以在使用公共url时访问它

3.-“即使有办法进行身份验证,子目录的不存在是否仍然是一个问题?”

不确定你是说这里,你可以
from google.cloud import storage
from google.oauth2 import service_account

VISION_SCOPES = ['https://www.googleapis.com/auth/devstorage']
SERVICE_ACCOUNT_FILE = 'key.json'

cred = service_account.Credentials.from_service_account_file(SERVICE_ACCOUNT_FILE, scopes=VISION_SCOPES)

client = storage.Client(credentials=cred)
bucket = client.get_bucket("bucket_name")
blob = bucket.blob("path/object.ext")
data = blob.download_as_string()
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'your_gcs_auth.json'

# initiate storage
client = storage.Client()
en_bucket = client.get_bucket('your-gcs-bucketname')

# get blob
en_model_blob = en_bucket.get_blob('your-modelname-in-gcsbucket.bin')
en_model = en_model_blob.download_as_string()

# because model downloaded into string, need to convert it back
buffer = io.BytesIO(en_model)

# prepare loading model
state_dict = torch.load(buffer, map_location=torch.device('cpu'))
model = BertForTokenClassification.from_pretrained(pretrained_model_name_or_path=None, state_dict=state_dict, config=main_config)
model.load_state_dict(state_dict)