I have a batch of segmented images, seg, of shape [batch, channels, imsize, imsize] = [16, 6, 50, 50]. Each scalar in this tensor specifies one of the segmentation classes, and there are 2000 segmentation classes in total.
Now the goal is to convert [16, 6, 50, 50] --> [16, 2000, 50, 50], where each class is one-hot encoded.
How can I do this with the PyTorch API? I can only think of a ridiculously inefficient looping construction.
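One vectorized option, offered as a minimal sketch rather than the only way, is Tensor.scatter_ along the channel dimension: it writes a 1 into out[b, seg[b, c, h, w], h, w] for every input position, with no Python loop. The helper name seg_to_one_hot and the uint8 output dtype are my own assumptions (uint8 keeps the [16, 2000, 50, 50] result at about 80 MB); switch to a float dtype if the result feeds something that expects floats.

```python
import torch

def seg_to_one_hot(seg: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: [B, C, H, W] labels -> [B, num_classes, H, W] one-hot map.

    Every label seg[b, c, h, w] sets out[b, label, h, w] = 1, whichever
    input channel c it came from.
    """
    b, _, h, w = seg.shape
    # uint8 is an assumed dtype choice to keep memory low.
    out = torch.zeros(b, num_classes, h, w, dtype=torch.uint8, device=seg.device)
    # Scatter along dim 1: out[b][seg[b, c, h, w]][h][w] = 1 for every c.
    # The index must be int64; it may be smaller than out along dim 1.
    out.scatter_(1, seg.long(), 1)
    return out

# Random labels with the question's shapes.
seg = torch.randint(0, 2000, (16, 6, 50, 50))
one_hot = seg_to_one_hot(seg, 2000)
print(one_hot.shape)  # torch.Size([16, 2000, 50, 50])
```

Because scatter_ only assigns, a label that happens to occur in several input channels at the same pixel still produces a single 1.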
Example
Here there are only 2 input channels (instead of 6), 4 labels (instead of 2000), a batch size of 1 (instead of 16) and a 4x4 image (instead of 50x50).
Input channel 0:
0, 0, 1, 1
0, 0, 0, 1
1, 1, 1, 1
1, 1, 1, 1

Input channel 1:
3, 3, 2, 2
3, 3, 2, 2
3, 3, 2, 2
3, 3, 2, 2
This turns into a 4-channel output, one channel per label:
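Output channel 0 (label 0):
1, 1, 0, 0
1, 1, 1, 0
0, 0, 0, 0
0, 0, 0, 0

Output channel 1 (label 1):
0, 0, 1, 1
0, 0, 0, 1
1, 1, 1, 1
1, 1, 1, 1

Output channel 2 (label 2):
0, 0, 1, 1
0, 0, 1, 1
0, 0, 1, 1
0, 0, 1, 1

Output channel 3 (label 3):
1, 1, 0, 0
1, 1, 0, 0
1, 1, 0, 0
1, 1, 0, 0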
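The toy example can be reproduced in a few lines as a sanity check. This sketch assumes the convention that output channel k marks exactly the pixels whose label is k, and it uses Tensor.scatter_ on the channel dimension to build that mask for all labels at once.

```python
import torch

# Toy input from the example: batch 1, 2 input channels, 4x4 image, labels 0..3.
seg = torch.tensor([[[[0, 0, 1, 1],
                      [0, 0, 0, 1],
                      [1, 1, 1, 1],
                      [1, 1, 1, 1]],
                     [[3, 3, 2, 2],
                      [3, 3, 2, 2],
                      [3, 3, 2, 2],
                      [3, 3, 2, 2]]]])  # shape [1, 2, 4, 4]

num_labels = 4
out = torch.zeros(1, num_labels, 4, 4, dtype=torch.long)
out.scatter_(1, seg, 1)  # out[0, seg[0, c, h, w], h, w] = 1

# Print one 4x4 mask per label.
for k in range(num_labels):
    print(f"label {k}:")
    print(out[0, k])
```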
The key observation is that a given label appears in only one input channel.
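That observation suggests another sketch: one-hot encode each input channel separately with torch.nn.functional.one_hot and add the results, since disjoint labels mean the per-channel maps never overlap and the sum stays 0/1. The helper name one_hot_by_sum and the toy data (channel 0 holding labels {0, 1}, channel 1 holding {2, 3}) are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def one_hot_by_sum(seg: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: [B, C, H, W] labels -> [B, num_classes, H, W].

    Relies on each label living in a single input channel, so the
    per-channel one-hot maps never overlap and their sum is 0/1.
    """
    per_channel = F.one_hot(seg.long(), num_classes)  # [B, C, H, W, K]
    merged = per_channel.sum(dim=1)                   # [B, H, W, K]
    return merged.permute(0, 3, 1, 2)                 # [B, K, H, W]

# Toy data shaped like the example: channel 0 holds labels {0, 1},
# channel 1 holds labels {2, 3}, so labels never repeat across channels.
offsets = torch.tensor([0, 2]).view(1, 2, 1, 1)
seg = torch.randint(0, 2, (1, 2, 4, 4)) + offsets
out = one_hot_by_sum(seg, num_classes=4)
print(out.shape)  # torch.Size([1, 4, 4, 4])
```

If labels could repeat across channels at the same pixel, replacing .sum(dim=1) with .amax(dim=1) keeps the output binary. The catch is the intermediate [B, C, H, W, K] int64 tensor: for [16, 6, 50, 50] with K = 2000 it holds 480 million entries (about 3.8 GB), so an approach that writes straight into the [B, K, H, W] output is cheaper at full size.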
What does the 6 represent? Your goal is to convert 16 6-channel 50x50 images into 16 2000-channel 50x50 images, so you're going from 15,000 points per image to 5,000,000 points per image. Are you sure this is what you're looking for? – Ivan
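The size blow-up in that comment is easy to check, along with what the full batch would cost in memory. The dtypes below are assumptions, since the question does not say which one the one-hot tensor should use.

```python
import torch

batch, in_ch, out_ch, h, w = 16, 6, 2000, 50, 50

in_per_image = in_ch * h * w    # 15,000 label values per image
out_per_image = out_ch * h * w  # 5,000,000 one-hot entries per image
print(in_per_image, out_per_image)

# Whole-batch footprint of the one-hot tensor for a few candidate dtypes.
for dtype in (torch.uint8, torch.float32, torch.int64):
    nbytes = batch * out_per_image * torch.empty((), dtype=dtype).element_size()
    print(f"{dtype}: {nbytes / 1e6:.0f} MB")
```

This works out to roughly 80 MB as uint8, 320 MB as float32 and 640 MB as int64 for the whole batch.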