Python 基于BERT的NER模型在反序列化时给出不一致的预测

Python 基于BERT的NER模型在反序列化时给出不一致的预测,python,pytorch,bert-language-model,huggingface-transformers,Python,Pytorch,Bert Language Model,Huggingface Transformers,我正在尝试使用Colab cloud GPU上的HuggingFace transformers库训练一个NER模型,对其进行pickle处理,然后将模型加载到我自己的CPU上进行预测 代码 模型如下所示: from transformers import BertForTokenClassification model = BertForTokenClassification.from_pretrained( "bert-base-cased", num_

我正在尝试使用Colab cloud GPU上的HuggingFace transformers库训练一个NER模型,对其进行pickle处理,然后将模型加载到我自己的CPU上进行预测

代码

模型如下所示:

from transformers import BertForTokenClassification

model = BertForTokenClassification.from_pretrained(
    "bert-base-cased",
    num_labels=NUM_LABELS,
    output_attentions = False,
    output_hidden_states = False
)
我使用这个片段在Colab上保存模型

import torch

torch.save(model.state_dict(), FILENAME)
然后使用

# Initiating an instance of the model type

model_reload = BertForTokenClassification.from_pretrained(
    "bert-base-cased",
    num_labels=len(tag2idx),
    output_attentions = False,
    output_hidden_states = False
)

# Loading the model
model_reload.load_state_dict(torch.load(FILENAME, map_location='cpu'))
model_reload.eval()

用于标记文本和进行实际预测的代码片段在Colab GPU笔记本实例和我的CPU笔记本实例上都是相同的

预期行为

GPU训练的模型行为正确,并对以下标记进行了完美分类:

O       [CLS]
O       Good
O       morning
O       ,
O       my
O       name
O       is
B-per   John
I-per   Kennedy
O       and
O       I
O       am
O       working
O       at
B-org   Apple
O       in
O       the
O       headquarters
O       of
B-geo   Cupertino
O       [SEP]
实际行为

加载模型并使用它在我的CPU上进行预测时,预测完全错误:

I-eve   [CLS]
I-eve   Good
I-eve   morning
I-eve   ,
I-eve   my
I-eve   name
I-eve   is
I-geo   John
B-eve   Kennedy
I-eve   and
I-eve   I
I-eve   am
I-eve   working
I-eve   at
I-gpe   Apple
I-eve   in
I-eve   the
I-eve   headquarters
I-eve   of
B-org   Cupertino
I-eve   [SEP]

有人知道为什么它不起作用吗?我错过了什么吗?

我修复了它,有两个问题:

  • 令牌的索引标签映射是错误的,出于某种原因,list()函数在Colab GPU上的工作方式与我的CPU(?)不同

  • 用于保存模型的代码段不正确,对于基于huggingface transformers库的模型,您不能使用模型。保存\u dict()并稍后加载,您需要使用模型类的save\u pretrained()方法,稍后使用from\u pretrained()加载它


  • 您能否与我们分享您的
    状态_dict
    。我想你应该打开一个bug报告。