Pytorch 带填充掩模的变换编码器

Pytorch 带填充掩模的变换编码器,pytorch,transformer,attention-model,Pytorch,Transformer,Attention Model,我正在尝试实现torch.nn.TransformerEncoder,其src_key_padding_mask不等于none。假设输入是形状src=[20,95],二进制填充掩码的形状src\u掩码=[20,95],1位于填充标记的位置,0表示其他位置。我制作了一个8层的transformer编码器,每个层包含一个带有8个头部和隐藏尺寸256的注意: layer=torch.nn.TransformerEncoderLayer(256, 8, 256, 0.1) encoder=torch.n

我正在尝试实现torch.nn.TransformerEncoder,其src_key_padding_mask不等于none。假设输入是形状
src=[20,95]
,二进制填充掩码的形状
src\u掩码=[20,95]
,1位于填充标记的位置,0表示其他位置。我制作了一个8层的transformer编码器,每个层包含一个带有8个头部和隐藏尺寸256的注意:

layer=torch.nn.TransformerEncoderLayer(256, 8, 256, 0.1)
encoder=torch.nn.TransformerEncoder(layer, 6)
embed=torch.nn.Embedding(80000, 256)
src=torch.randint(0, 1000, (20, 95))
src = emb(src)
src_mask = torch.randint(0,2,(20, 95))
output =  encoder(src, src_mask)
但我得到了以下错误:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-107-31bf7ab8384b> in <module>
----> 1 output =  encoder(src, src_mask)

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    545             result = self._slow_forward(*input, **kwargs)
    546         else:
--> 547             result = self.forward(*input, **kwargs)
    548         for hook in self._forward_hooks.values():
    549             hook_result = hook(self, input, result)

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/transformer.py in forward(self, src, mask, src_key_padding_mask)
    165         for i in range(self.num_layers):
    166             output = self.layers[i](output, src_mask=mask,
--> 167                                     src_key_padding_mask=src_key_padding_mask)
    168 
    169         if self.norm:

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    545             result = self._slow_forward(*input, **kwargs)
    546         else:
--> 547             result = self.forward(*input, **kwargs)
    548         for hook in self._forward_hooks.values():
    549             hook_result = hook(self, input, result)

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/transformer.py in forward(self, src, src_mask, src_key_padding_mask)
    264         """
    265         src2 = self.self_attn(src, src, src, attn_mask=src_mask,
--> 266                               key_padding_mask=src_key_padding_mask)[0]
    267         src = src + self.dropout1(src2)
    268         src = self.norm1(src)

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    545             result = self._slow_forward(*input, **kwargs)
    546         else:
--> 547             result = self.forward(*input, **kwargs)
    548         for hook in self._forward_hooks.values():
    549             hook_result = hook(self, input, result)

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/activation.py in forward(self, query, key, value, key_padding_mask, need_weights, attn_mask)
    781                 training=self.training,
    782                 key_padding_mask=key_padding_mask, need_weights=need_weights,
--> 783                 attn_mask=attn_mask)
    784 
    785 

~/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py in multi_head_attention_forward(query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training, key_padding_mask, need_weights, attn_mask, use_separate_proj_weight, q_proj_weight, k_proj_weight, v_proj_weight, static_k, static_v)
   3250     if attn_mask is not None:
   3251         attn_mask = attn_mask.unsqueeze(0)
-> 3252         attn_output_weights += attn_mask
   3253 
   3254     if key_padding_mask is not None:

RuntimeError: The size of tensor a (20) must match the size of tensor b (95) at non-singleton dimension 2
---------------------------------------------------------------------------
运行时错误回溯(上次最近调用)
在里面
---->1输出=编码器(src、src\U掩码)
~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py在调用中(self,*input,**kwargs)
545结果=self.\u slow\u forward(*输入,**kwargs)
546其他:
-->547结果=自我转发(*输入,**kwargs)
548用于钩住自身。\u向前\u钩住.values():
549钩子结果=钩子(自身、输入、结果)
~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/transformer.py前进(self、src、mask、src\u key\u padding\u mask)
165适用于范围内的i(self.num_层):
166输出=自身层[i](输出,src_掩码=掩码,
-->167 src_key_padding_mask=src_key_padding_mask)
168
169如果自我规范:
~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py在调用中(self,*input,**kwargs)
545结果=self.\u slow\u forward(*输入,**kwargs)
546其他:
-->547结果=自我转发(*输入,**kwargs)
548用于钩住自身。\u向前\u钩住.values():
549钩子结果=钩子(自身、输入、结果)
~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/transformer.py正向(self、src、src\u掩码、src\u key\u padding\u掩码)
264         """
265 src2=self.self_attn(src,src,src,attn_mask=src_mask,
-->266键填充掩码=src键填充掩码[0]
267 src=src+self.dropout1(src2)
268 src=自身规范1(src)
~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py在调用中(self,*input,**kwargs)
545结果=self.\u slow\u forward(*输入,**kwargs)
546其他:
-->547结果=自我转发(*输入,**kwargs)
548用于钩住自身。\u向前\u钩住.values():
549钩子结果=钩子(自身、输入、结果)
~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/activation.py前进(self、query、key、value、key\u padding\u mask、need\u weights、attn\u mask)
781培训=自我培训,
782键填充掩码=键填充掩码,需要权重=需要权重,
-->783附件屏蔽=附件屏蔽)
784
785
~/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py在multi\u head\u attention\u forward中(查询、键、值、嵌入尺寸到检查、数字头、项目内权重、项目内偏差、偏差、偏差、添加零附件、退出附件、项目外权重、项目外偏差、培训、关键填充掩码、需要权重、附件掩码、使用单独项目权重、q项目权重、k项目权重、项目外权重、静态、静态)
3250如果附件掩码不是无:
3251附件掩码=附件掩码。未查询(0)
->3252附件输出附件权重+=附件屏蔽
3253
3254如果“键填充”掩码不是“无”:
RuntimeError:张量a(20)的大小必须与张量b(95)在非单态维度2的大小相匹配
我想知道是否有人能帮我解决这个问题


感谢

中显示了所需的形状(变压器的所有构建块都参考它)。编码器的相关形状如下:

  • src:(南、北、东)
  • src_面罩:(S,S)
  • src_键_填充_掩码:(N,S)
其中S是序列长度,N是批量大小,E是嵌入维度(特征数量)

填充掩码的形状应为[95,20],而不是[20,95]。这假设批大小为95,序列长度为20,但如果是相反的情况,则必须将
src
转置

此外,调用编码器时,您没有指定
src\u key\u padding\u mask
,而是指定
src\u mask
,作为的签名是:

forward(src,mask=None,src\u key\u padding\u mask=None)
填充掩码必须指定为关键字参数
src\u key\u padding\u mask
而不是第二个位置参数。为避免混淆,您的
src\u mask
应重命名为
src\u key\u padding\u mask

src_key_padding_mask=torch.randint(0,2,(95,20))
输出=编码器(src,src\U键\U填充\U掩码=src\U键\U填充\U掩码)