PyTorch中非随机替换的均匀采样

PyTorch中非随机替换的均匀采样,pytorch,Pytorch,给定形状为4x4的布尔张量(遮罩): x = torch.tensor([ [ True, True, False, True ], [ False, False, False, True ], [ False, True, True, True ], [ False, True, True, False ], ]) 我想根据以下4x7形状的结果对(7)进行相应的取样: tensor([[0, 0, 0, 1, 1, 3, 3], [3, 3

给定形状为4x4的布尔张量(遮罩):

x = torch.tensor([ 
    [ True, True, False, True ], 
    [ False, False, False, True ],
    [ False, True, True, True ],
    [ False, True, True, False ],
])
我想根据以下4x7形状的结果对(7)进行相应的取样:

tensor([[0, 0, 0, 1, 1, 3, 3],
        [3, 3, 3, 3, 3, 3, 3],
        [1, 1, 1, 2, 2, 3, 3],
        [1, 1, 1, 1, 2, 2, 2]])
最接近我的是以下实现:

def uniform_sampling(tensor, count = 1):
    indices = torch.arange(0, tensor.shape[-1], device = tensor.device).expand(tensor.shape)
    samples_count = tensor.long().sum(-1)
    output = tensor.long() * (count // samples_count)[:, None]
    remainder = count - output.sum(-1)
    
    rem1 = torch.stack((remainder, tensor.sum(-1) - remainder), -1).flatten()
    rem2 = torch.stack((torch.ones_like(remainder), torch.zeros_like(remainder)), -1).flatten()
    remaining = rem2.repeat_interleave(rem1, 0)
    
    output[tensor > 0] += remaining
    samples = indices[tensor].repeat_interleave(output[tensor], -1).view(-1, count)
    
    return samples

uniform_sampling(x, count = 7)
是否有任何(可能是本地的)Pytork功能可以实现同样的功能,但速度更快、效率更高