
I'm attempting to plot the output array of a U-Net. The array contains MNIST data that has been one-hot encoded for image segmentation.

Its shape is (28, 28, 11).

So, for each pixel in the original image whose value is 0, the one-hot encoding holds the vector [0 0 0 0 0 0 0 0 0 0 1], indicating that the pixel is blank (background).

If, on the other hand, the pixel value is > 0, the one-hot vector encodes the class of the whole image.

For example, if the MNIST image is a 2, every pixel with a value > 0 becomes [0 0 1 0 0 0 0 0 0 0 0].
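A minimal NumPy sketch of the encoding described above. `one_hot_segmentation` is a hypothetical helper, and the lit square is synthetic data standing in for a real MNIST digit; it only illustrates how each pixel turns into an 11-element one-hot vector, with index 10 as background.

```python
import numpy as np

def one_hot_segmentation(image, label, num_classes=11):
    """Hypothetical helper: encode a 2-D grayscale image as (H, W, num_classes).

    Pixels equal to 0 go to the last channel (background, index 10);
    pixels > 0 go to the channel of the image's digit label.
    """
    background = num_classes - 1
    # Per-pixel class index: the digit label where the pixel is lit, else background
    class_map = np.where(image > 0, label, background)
    # Indexing the identity matrix picks one one-hot row per pixel
    return np.eye(num_classes, dtype=np.uint8)[class_map]

# Synthetic 28x28 "digit": a lit square standing in for an MNIST 2
image = np.zeros((28, 28), dtype=np.uint8)
image[10:18, 10:18] = 200

encoded = one_hot_segmentation(image, label=2)
print(encoded.shape)    # (28, 28, 11)
print(encoded[0, 0])    # [0 0 0 0 0 0 0 0 0 0 1]
print(encoded[12, 12])  # [0 0 1 0 0 0 0 0 0 0 0]
```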

I'm wondering if there's a way to display such an array, where each pixel is itself a one-hot vector.

I've tried calling plt.imshow on the data directly, but it raises "TypeError: Invalid dimensions for image data".

Here is the code I'm working with:

import sys
import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist

np.set_printoptions(threshold=sys.maxsize)

(x_train, y_train), (x_test, y_test) = mnist.load_data()

y_train = y_train[:10]

data = np.random.choice(255, (10,128,128))


## do your calculation of brightness here
## and expand it to one row per pixel
arr = data.reshape(-1,1)/255
## repeat labels to match the expanded pixel
labels = y_train.repeat(128*128).reshape(-1,1)

ind_row = np.arange(len(arr))
ind_col = np.where(arr>0, labels, 10).ravel()

one_hot_coded_arr = np.zeros((len(arr), 11))
one_hot_coded_arr[ind_row,ind_col]=1

## convert back to desired shape
one_hot_coded_arr = one_hot_coded_arr.reshape(-1, 128,128,11)
#print(one_hot_coded_arr[:28,:])
print(one_hot_coded_arr.shape)


plt.imshow(one_hot_coded_arr, interpolation='nearest')  # raises TypeError: the array is 4-D (10, 128, 128, 11)
plt.axis("off")
plt.show()

I want to display an image like this one: https://documentation.sas.com/api/docsets/casdlpg/8.4/content/images/mnistout2.png

But I keep running into the error "TypeError: Invalid dimensions for image data"

Any help would be amazing, thanks!


Answer


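You have too many dimensions. plt.imshow only accepts a 2-D array of shape (height, width), or a 3-D array whose last axis has 3 (RGB) or 4 (RGBA) channels. Your array is 4-D (batch, height, width, 11 classes), and even a single image of shape (28, 28, 11) has too many channels.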
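A quick way to check the shape rule is to call imshow on a few dummy arrays. `imshow_accepts` is a hypothetical helper written for this sketch, and the Agg backend is chosen only so it runs without a display; the exact TypeError message differs between matplotlib versions, but the exception type is the same.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

def imshow_accepts(arr):
    """Hypothetical helper: True if imshow accepts arr's shape, False on TypeError."""
    fig, ax = plt.subplots()
    try:
        ax.imshow(arr)
        return True
    except TypeError:
        return False
    finally:
        plt.close(fig)

print(imshow_accepts(np.zeros((28, 28))))            # True: 2-D scalar image
print(imshow_accepts(np.zeros((28, 28, 3))))         # True: RGB image
print(imshow_accepts(np.zeros((28, 28, 11))))        # False: 11 channels
print(imshow_accepts(np.zeros((10, 128, 128, 11))))  # False: 4-D batch
```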

Therefore, first choose the image to display, i.e. output[n]. Then, since it is one-hot encoded, use np.argmax(output[n], axis=-1) to 'unencode' it into a 2-D map of class indices (0 to 10), which imshow can plot.
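A minimal sketch of the argmax approach, assuming `output` has shape (batch, height, width, 11) with class 10 as background; the array here is synthetic, and `decode_one_hot` is a hypothetical helper. It pairs the decoded map with a discrete 11-color colormap (an assumed choice: tab10 for the digits, black for background) so that each class gets its own color, roughly like the linked SAS image.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; use plt.show() with an interactive one
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import numpy as np

def decode_one_hot(output, n):
    """Hypothetical helper: collapse image n of (batch, H, W, C) to an (H, W) class map."""
    return np.argmax(output[n], axis=-1)

# Synthetic stand-in for the network output: 2 images, 28x28, 11 classes
output = np.zeros((2, 28, 28, 11))
output[..., 10] = 1               # every pixel starts as background (class 10)
output[0, 8:20, 8:20, 10] = 0     # a square in image 0 ...
output[0, 8:20, 8:20, 2] = 1      # ... labelled as digit 2

class_map = decode_one_hot(output, 0)  # shape (28, 28), integers in 0..10

# One color per class: tab10 for digits 0-9, black for background (class 10)
colors = list(plt.get_cmap("tab10").colors) + [(0.0, 0.0, 0.0)]
cmap = ListedColormap(colors)

fig, ax = plt.subplots()
# vmin/vmax pin class k to color k even when some classes are absent
ax.imshow(class_map, cmap=cmap, vmin=0, vmax=10, interpolation="nearest")
ax.axis("off")
fig.savefig("segmentation.png")
```

Fixing vmin and vmax matters here: without them, imshow would rescale the colors to whatever classes happen to appear in each image, so the same digit could get different colors across images.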

Alternatively, just select the single class channel (layer) you want, output[n, :, :, l], which is a 2-D mask for class l.
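The per-channel alternative can be sketched as a row of 11 binary masks, one per class, again on a synthetic array of shape (batch, height, width, 11). `plot_channels` is a hypothetical helper; the "bg" title for channel 10 is an assumed label for the background class.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; use plt.show() with an interactive one
import matplotlib.pyplot as plt
import numpy as np

def plot_channels(output, n, num_classes=11):
    """Hypothetical helper: draw each channel output[n, :, :, l] as its own mask."""
    fig, axes = plt.subplots(1, num_classes, figsize=(2 * num_classes, 2))
    for l, ax in enumerate(axes):
        # Each channel is a 2-D array of 0s and 1s, which imshow accepts
        ax.imshow(output[n, :, :, l], cmap="gray", vmin=0, vmax=1)
        ax.set_title("bg" if l == num_classes - 1 else str(l))
        ax.axis("off")
    return fig

# Synthetic batch: image 0 is background except for a square of digit 2
output = np.zeros((1, 28, 28, 11))
output[..., 10] = 1
output[0, 8:20, 8:20, 10] = 0
output[0, 8:20, 8:20, 2] = 1

fig = plot_channels(output, 0)
fig.savefig("channels.png")
```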

Let me know if that works.