PyTorch：基于多进程（multiprocessing）的 Hugging Face BERT 分类

PyTorch：基于多进程（multiprocessing）的 Hugging Face BERT 分类。标签：pytorch, python-multiprocessing, huggingface-transformers, distilbert。摘要：我试图使用 torch 的多进程（multiprocessing）来并行化两个相互独立的 Hugging Face 分类模型的预测，但程序似乎在预测阶段发生了死锁。我使用的是 python 3.6.5、torch 1.5.0 和 huggingface transformers 2.11.0。运行代码的输出是 Tree enc done Begin tree prediction<------(Comment: Both begin tree End tree predictions<------- and end tree predictions) ……

我试图使用 torch 的多进程（multiprocessing）来并行化两个相互独立的 Hugging Face 分类模型的预测，但程序似乎在预测阶段发生了死锁。我使用的是 python 3.6.5、torch 1.5.0 和 huggingface transformers 2.11.0。运行代码的输出如下（先是顺序预测的输出，然后是并行预测的输出），其后是完整代码：

Tree enc done
Begin tree prediction<------(Comment: Both begin tree
End tree predictions<-------  and end tree predictions)
0.03125429153442383
Dn prediction
Dn enc done
Begin dn predictions<------(Comment: Both begin dn
End dn predictions<-------  and end dn predictions)
0.029727697372436523
----------Done sequential predictions-------------

--------Start Parallel predictions--------------
Tree prediction
Tree enc done
Begin tree prediction. <------(Comment: Process is deadlocked after this)
Dn prediction
Dn enc done
Begin dn predictions. <-------(Comment: Process is deadlocked after this)
import torch
import torch.multiprocessing as mp
import time
import transformers
from transformers import  DistilBertForSequenceClassification

# Load the BERT tokenizer.
from transformers import BertTokenizer
# Tokenizer shared by every prediction call (lower-cases its input).
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

# Two independent two-label DistilBERT classifiers, both created from the
# same pretrained checkpoint and configured not to return attentions or
# hidden states.
tree_model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2,
    output_attentions=False, output_hidden_states=False,
)
dn_model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2,
    output_attentions=False, output_hidden_states=False,
)

# The models are only used for inference, so switch both to evaluation mode.
tree_model.eval()
dn_model.eval()

# Put the model tensors in shared memory before they are handed to the
# worker processes started further down in the script.
tree_model.share_memory()
dn_model.share_memory()


def predict(sentences=(), tokenizer=tokenizer, models=(tree_model, dn_model, None)):
    """Classify ``sentences`` with one of the supplied models and return logits.

    Args:
        sentences: Iterable of input sentences. Each one is encoded on its
            own with ``max_length=MAX_SENTENCE_LENGTH`` and
            ``pad_to_max_length=True``; the per-sentence tensors are then
            concatenated along dimension 0.
        tokenizer: Object providing ``encode_plus`` that returns a mapping
            with ``'input_ids'`` and ``'attention_mask'`` tensors.
        models: Sequence whose first entry is the "tree" model and whose
            second entry is the "dn" model; either may be ``None``. When
            both are present the tree model is the one that runs. Any
            further entries are ignored.

    Returns:
        The first element of the model output (the logits), detached and
        moved to the CPU.

    Raises:
        ValueError: If both model slots are ``None``.

    Note:
        An empty ``sentences`` is not handled specially: ``torch.cat`` is
        then called on an empty list and fails.
    """
    MAX_SENTENCE_LENGTH = 16
    start = time.time()

    tree_model, dn_model = models[0], models[1]
    # Test against None explicitly instead of relying on the truthiness of a
    # model object, which is not a reliable "was a model supplied" signal.
    use_tree = tree_model is not None
    use_dn = dn_model is not None
    if not (use_tree or use_dn):
        # Previously this surfaced much later as "'NoneType' object is not
        # callable" from inside the forward pass.
        raise ValueError("predict() needs a model in models[0] or models[1]")

    if use_tree:
        print("Tree prediction")
    if use_dn:
        print("Dn prediction")

    encodings = [
        tokenizer.encode_plus(
            sent,
            add_special_tokens=True,
            max_length=MAX_SENTENCE_LENGTH,
            pad_to_max_length=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        for sent in sentences
    ]

    if use_tree:
        print("Tree enc done")
    if use_dn:
        print("Dn enc done")

    # Stack the per-sentence tensors into single batch tensors.
    new_input_ids = torch.cat([enc['input_ids'] for enc in encodings], dim=0)
    new_attention_masks = torch.cat(
        [enc['attention_mask'] for enc in encodings], dim=0
    )

    # Inference only: no gradients are needed for the forward pass.
    with torch.no_grad():
        if use_tree:
            print("Begin tree prediction")
            outputs = tree_model(new_input_ids,
                                 attention_mask=new_attention_masks)
            print("End tree predictions")
        else:
            print("Begin dn predictions")
            outputs = dn_model(new_input_ids,
                               attention_mask=new_attention_masks)
            print("End dn predictions")

    logits = outputs[0].detach().cpu()
    # Elapsed wall-clock time for encoding plus the forward pass, in seconds.
    print(time.time() - start)
    return logits



def get_tree_prediction(sentence, tokenizer=tokenizer,models=(tree_model,dn_model, None)):
    """Run ``predict`` on a single sentence.

    The sentence is wrapped in a one-element list; ``tokenizer`` and
    ``models`` are forwarded to ``predict`` unchanged.
    """
    single_sentence_batch = [sentence]
    return predict(single_sentence_batch, tokenizer, models)

def get_dn_prediction(sentence, tokenizer=tokenizer,models=(tree_model,dn_model, None)):
    """Run ``predict`` on a single sentence using the dn model.

    Args:
        sentence: The sentence to classify.
        tokenizer: Tokenizer forwarded to ``predict``.
        models: Sequence whose second entry is the dn model. The first
            entry (the tree slot) is deliberately not forwarded.

    Returns:
        Whatever ``predict`` returns for the one-sentence batch.

    Raises:
        ValueError: If ``models[1]`` is ``None``.
    """
    if models[1] is None:
        raise ValueError("get_dn_prediction() needs the dn model in models[1]")
    # predict() runs the tree model whenever the first slot is filled, so with
    # the default ``models`` this helper used to return tree-model output.
    # Blank the tree slot so the dn model is always the one that runs.
    dn_only = (None, models[1]) + tuple(models[2:])
    return predict(sentences=[sentence], tokenizer=tokenizer, models=dn_only)


if __name__ == '__main__':
    # Run the two predictions sequentially first, then again in two worker
    # processes.
    #
    # The sequential forward passes run in this (parent) process. With more
    # than one intra-op thread they start PyTorch's CPU thread pool here;
    # worker processes created by fork afterwards inherit that pool's state
    # but not its threads, and their first forward pass can block forever
    # (the reported hang right after "Begin tree prediction" /
    # "Begin dn predictions"). Limiting PyTorch to a single intra-op thread
    # before any forward pass avoids creating the pool in the first place.
    torch.set_num_threads(1)

    sentence = "hello world"
    get_tree_prediction(sentence, tokenizer, (tree_model, None, None))
    get_dn_prediction(sentence, tokenizer, (None, dn_model, None))
    print("----------Done sequential predictions-------------")

    print('\n--------Start Parallel predictions--------------')
    processes = []

    tr_p = mp.Process(target=get_tree_prediction,
                      args=(sentence, tokenizer, (tree_model, None, None)))
    tr_p.start()
    processes.append(tr_p)

    dn_p = mp.Process(target=get_dn_prediction,
                      args=(sentence, tokenizer, (None, dn_model, None)))
    dn_p.start()
    processes.append(dn_p)

    # Wait for both workers to finish before the script exits.
    for p in processes:
        p.join()
上述代码是参照某个示例编写的（原文此处的来源链接已缺失）。