Python 分段到一个热编码

Python 分段到一个热编码,python,numpy,pytorch,Python,Numpy,Pytorch,我有一批大小不同的分割图像 seg-->[批次、频道、imsize、imgsize]-->[16,6,50,50] 该张量中的每个标量指定一个分段类。 我们有2000个total segmentation类 现在的目标是转化 [16,6,50,50]-->[16,2000,50,50] 其中每个类都以一种热门方式编码 我如何使用pytorch api? 我只能想到效率低得可笑的循环构造 范例 在这里,我们只有2个初始通道(而不是6个)、4个标签(而不是2000个)、1号批次(而不是16个)和4x

我有一批大小不同的分割图像

seg
-->
[批次、频道、imsize、imgsize]
-->
[16,6,50,50]

该张量中的每个标量指定一个分段类。 我们有
2000个
total segmentation类

现在的目标是转化
[16,6,50,50]
-->
[16,2000,50,50]
其中每个类都以一种热门方式编码

我如何使用pytorch api? 我只能想到效率低得可笑的循环构造

范例

在这里,我们只有2个初始通道(而不是6个)、4个标签(而不是2000个)、1号批次(而不是16个)和4x4图像(而不是50x50个)

0, 0, 1, 1
0, 0, 0, 1
1, 1, 1, 1
1, 1, 1, 1

3, 3, 2, 2
3, 3, 2, 2
3, 3, 2, 2
3, 3, 2, 2
现在变成4通道输出

1, 1, 0, 0
1, 1, 1, 0
0, 0, 0, 0
0, 0, 0, 0

0, 0, 1, 1
0, 0, 0, 1
1, 1, 1, 1
1, 1, 1, 1

1, 1, 0, 0
1, 1, 0, 0
1, 1, 0, 0
1, 1, 0, 0

0, 0, 1, 1
0, 0, 1, 1
0, 0, 1, 1
0, 0, 1, 1

关键的观察结果是,一个特定的标签只出现在一个输入通道上。

我认为您可以轻松实现这一点。构造尽可能多的遮罩,然后将这些遮罩堆叠在一起,在通道层上求和并转换为浮点数:

>>> x
tensor([[[0, 0, 1, 1],
         [0, 0, 0, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1]],

        [[3, 3, 2, 2],
         [3, 3, 2, 2],
         [3, 3, 2, 2],
         [3, 3, 2, 2]]])

>>> y = torch.stack([x==i for i in range(x.max()+1)], dim=1).sum(dim=2)
tensor([[[1., 1., 0., 0.],
         [1., 1., 1., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 1., 1.],
         [0., 0., 0., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[0., 0., 1., 1.],
         [0., 0., 1., 1.],
         [0., 0., 1., 1.],
         [0., 0., 1., 1.]],

        [[1., 1., 0., 0.],
         [1., 1., 0., 0.],
         [1., 1., 0., 0.],
         [1., 1., 0., 0.]]])

6
代表什么?您的目标是将16个6通道50x50图像转换为16个2000通道50x50图像。您将从15000点/张图片增加到5000000点/张图片。你确定这就是你要找的吗?@Ivan 6只是一些标签从0到2000的频道。这不是那么重要的细节。。。在通道0上只能显示标签子集,在通道1上可以显示另一个标签子集,等等。因此通道0将图像分割到对应于对象的标签上,通道1将图像分割到对应于零件的标签上,等等。现在,由于总共有2000个标签,一个热编码需要2000 x imgsize x imgsize。它有意义吗?在一个给定的通道上,比如说通道=0,你怎么知道它对应于哪个标签?我不清楚你会如何将一个6通道图像转换成同样大小的2000通道图像。@Ivan它将由一个标签号给出。。。假设通道0有标签10和一些其他标签。现在[:,10,:,:]将在同一位置上有1,在其他位置有10和0。它能使你快乐吗sense@Ivan当然,完成了!这就是我在torch中得到的。大小([16,6,50,50])在torch中得到的。大小([2000,6,50,50])因此有些不对劲…显然没有考虑批处理维度,至少是这样。另一个问题--它仍然使用循环而不是pytorch API…如果处理大数据,速度太慢了。现在应该解决这个问题,这是一个在正确轴上叠加/求和的问题。你必须意识到这不是一项直接的任务,你正在试图从多个通道中解开不确定数量的标签。。。仅使用张量可能无法实现这一点。