I have a batch of segmented images of size

seg --> [batch, channels, imgsize, imgsize] --> [16, 6, 50, 50]

Each scalar in this tensor specifies one of the segmentation classes. There are 2000 segmentation classes in total.

The goal is to convert [16, 6, 50, 50] --> [16, 2000, 50, 50], where each class is one-hot encoded.

How can I do this with the PyTorch API? All I can come up with is a hopelessly inefficient loop.

Example

Here there are only 2 input channels (instead of 6), 4 labels (instead of 2000), batch size 1 (instead of 16), and a 4x4 image (instead of 50x50).

0, 0, 1, 1
0, 0, 0, 1
1, 1, 1, 1
1, 1, 1, 1

3, 3, 2, 2
3, 3, 2, 2
3, 3, 2, 2
3, 3, 2, 2

This turns into a 4-channel output, one channel per label (labels 0, 1, 2, 3 in order):

1, 1, 0, 0
1, 1, 1, 0
0, 0, 0, 0
0, 0, 0, 0

0, 0, 1, 1
0, 0, 0, 1
1, 1, 1, 1
1, 1, 1, 1

0, 0, 1, 1
0, 0, 1, 1
0, 0, 1, 1
0, 0, 1, 1

1, 1, 0, 0
1, 1, 0, 0
1, 1, 0, 0
1, 1, 0, 0

The key observation is that a particular label appears only on a single input channel.
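A minimal sketch (not from the original post) of why that observation matters, assuming PyTorch is installed: torch.nn.functional.one_hot turns the [C, H, W] example into a [C, H, W, K] tensor, and summing over C gives the K-channel output above. The sum stays 0/1 only because each label lives in a single input channel. At full size the int64 intermediate for a [16, 6, 50, 50] input with 2000 classes would hold 480 million elements (about 3.8 GB), so this is mainly useful for checking the toy example.

```python
import torch
import torch.nn.functional as F

# Toy input from the question: 2 channels, 4x4 image, labels 0..3, no batch dim.
x = torch.tensor([
    [[0, 0, 1, 1],
     [0, 0, 0, 1],
     [1, 1, 1, 1],
     [1, 1, 1, 1]],
    [[3, 3, 2, 2],
     [3, 3, 2, 2],
     [3, 3, 2, 2],
     [3, 3, 2, 2]],
])
num_classes = 4  # 2000 in the real problem

# one_hot appends a class dimension: [C, H, W] -> [C, H, W, K].
oh = F.one_hot(x, num_classes=num_classes)

# Sum over the input channels. Each label occurs in only one channel,
# so every entry of the sum is 0 or 1.
y = oh.sum(dim=0)               # [H, W, K]
y = y.permute(2, 0, 1).float()  # [K, H, W]

print(y.shape)  # torch.Size([4, 4, 4])
```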

Ivan: What does the 6 represent? Your goal is to convert 16 six-channel 50x50 images into 16 2000-channel 50x50 images. You're going from 15,000 values per image to 5,000,000 values per image. Are you sure this is what you're looking for?
YohanRoth: @Ivan the 6 is just the number of channels holding labels from 0 to 1999; it's not an important detail. Each channel can contain only a subset of the labels: channel 0 partitions the image into labels that correspond to objects, channel 1 into labels that correspond to parts, and so on. Since there are 2000 labels in total, the one-hot encoding needs 2000 x imgsize x imgsize values per image. Does that make sense?
Ivan: OK, but on a given channel, say channel 0, how would you know which labels it corresponds to? I don't clearly see how you would convert a 6-channel image into a 2000-channel image of the same size.
YohanRoth: @Ivan it's given by the label number. Say channel 0 contains label 10 along with some other labels. Then [:, 10, :, :] would have a 1 at every location where the input has label 10, and 0 everywhere else. Does that make sense?
YohanRoth: @Ivan sure, it's done!

1 Answer

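I think you can achieve this without too much trouble. Construct one boolean mask per label, stack the masks along a new label dimension, sum over the input-channel dimension, and convert to floats. Because each label appears in only one input channel, every entry of the sum is 0 or 1. The example below uses the unbatched [channels, H, W] tensor from the question; for a batched [batch, channels, H, W] tensor, stack on dim=1 and sum over dim=2 instead, and use range(2000) rather than range(int(x.max()) + 1) so that the number of output channels does not depend on which labels happen to occur in the batch: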

>>> x
tensor([[[0, 0, 1, 1],
         [0, 0, 0, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1]],

        [[3, 3, 2, 2],
         [3, 3, 2, 2],
         [3, 3, 2, 2],
         [3, 3, 2, 2]]])

>>> y = torch.stack([x == i for i in range(int(x.max()) + 1)], dim=0).sum(dim=1).float()
>>> y
tensor([[[1., 1., 0., 0.],
         [1., 1., 1., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 1., 1.],
         [0., 0., 0., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[0., 0., 1., 1.],
         [0., 0., 1., 1.],
         [0., 0., 1., 1.],
         [0., 0., 1., 1.]],

        [[1., 1., 0., 0.],
         [1., 1., 0., 0.],
         [1., 1., 0., 0.],
         [1., 1., 0., 0.]]])
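Not part of the original answer: with 2000 labels, the stacked masks for a [16, 6, 50, 50] batch form a [16, 2000, 6, 50, 50] boolean tensor (480 million elements) before the sum. A possibly lighter alternative, sketched here under the assumption that all labels are integers in [0, num_classes), is to allocate the [batch, num_classes, H, W] output once and write the ones with Tensor.scatter_ along dim 1. The helper name one_hot_segmentation is hypothetical.

```python
import torch


def one_hot_segmentation(seg: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: [B, C, H, W] integer labels -> [B, num_classes, H, W] one-hot.

    Assumes every value in `seg` lies in [0, num_classes).
    """
    b, _, h, w = seg.shape
    out = torch.zeros(b, num_classes, h, w, dtype=torch.float32, device=seg.device)
    # For every (b, c, h, w): out[b, seg[b, c, h, w], h, w] = 1.0
    out.scatter_(1, seg.long(), 1.0)
    return out


# Toy example from the question, with a batch dimension of 1.
x = torch.tensor([[
    [[0, 0, 1, 1], [0, 0, 0, 1], [1, 1, 1, 1], [1, 1, 1, 1]],
    [[3, 3, 2, 2], [3, 3, 2, 2], [3, 3, 2, 2], [3, 3, 2, 2]],
]])
y = one_hot_segmentation(x, num_classes=4)
print(y.shape)  # torch.Size([1, 4, 4, 4])
```

Since scatter_ simply writes 1 at each (label, h, w) position, a label that appeared in several channels at the same pixel would still give 1 rather than 2, so this variant does not rely on the one-label-per-channel property. For the real data the call would be one_hot_segmentation(seg, 2000) on the [16, 6, 50, 50] tensor.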