Python PyTorch中形状相同的掩蔽张量

Python PyTorch中形状相同的掩蔽张量,python,pytorch,Python,Pytorch,给定相同形状的数组和掩码,我希望相同形状的掩码输出包含0,其中掩码为False 比如说, # input array img = torch.randn(2, 2) print(img) # tensor([[0.4684, 0.8316], # [0.8635, 0.4228]]) print(img.shape) # torch.Size([2, 2]) # mask mask = torch.BoolTensor(2, 2) print(mask) # tensor([[F

给定相同形状的数组和掩码,我希望相同形状的掩码输出包含0,其中掩码为False

比如说,

# input array
img = torch.randn(2, 2)
print(img)
# tensor([[0.4684, 0.8316],
#        [0.8635, 0.4228]])
print(img.shape)
# torch.Size([2, 2])

# mask
mask = torch.BoolTensor(2, 2)
print(mask)
# tensor([[False,  True],
#        [ True,  True]])
print(mask.shape)
# torch.Size([2, 2])

# expected masked output of shape 2x2
# tensor([[0, 0.8316],
#        [0.8635, 0.4228]])
问题:遮罩会更改输出的形状,如下所示:

#1: shape changed
img[mask]
# tensor([0.8316, 0.8635, 0.4228])

最直接的方法是创建另一个张量来处理它

import torch

def generate_masked_tensor(input, mask, fill=0):
    masked_tensor = torch.zeros(input.size()) + fill
    masked_tensor[mask] = input[mask]
    return masked_tensor

if __name__ == "__main__":
    img = torch.randn(2, 2)
    mask = torch.tensor([False, True, True, False]).bool().view(2, 2)
    masked_img = generate_masked_tensor(img, mask)
    print (masked_img)
输出:

tensor([[0.0000, 0.8028],
        [1.5411, 0.0000]])

只需键入将布尔掩码转换为整数掩码,然后键入float,将掩码转换为与
img
中相同的类型。然后执行元素相乘


masked\u output=img*mask.int()

img[mask==False] = 0
或使用

img[~mask] = 0
它将更改
img
本身