
I am trying to implement a neural network for classifying different defects in quality inspection, and I want to use single-class classification. To accomplish this, I want to train a generative adversarial network (GAN) and use its discriminator for classification.

I used the sunflower example to implement my first GAN. (https://de.mathworks.com/help/deeplearning/examples/train-generative-adversarial-network.html)

In this example, there is a line which "classifies" the generated outputs with the help of the discriminator network:

dlYPredGenerated = forward(dlnetDiscriminator, dlXGenerated); 

Now, I expected the output to consist of two labels: "Original" or "Fake". Instead, I get a long list of numbers:

(:,:,1,1) =
    5.9427
(:,:,1,2) =
    7.5930
(:,:,1,3) =
    9.3393
etc.

I think these are the loss values for the discriminator network.

I would like to know how I can use the resulting discriminator network to classify input images. The problem is that the discriminator network has no fully connected layers or classification layer at the end of its layer structure, so its architecture seems to differ from that of a "normal" convolutional neural network.

Summary

I want to use the MATLAB sunflower example (https://de.mathworks.com/help/deeplearning/examples/train-generative-adversarial-network.html) to train a GAN and extract the discriminator to function as a classification network.


1 Answer


In the MATLAB sunflower example, the line

dlYPredGenerated = forward(dlnetDiscriminator, dlXGenerated);

returns the raw output of the discriminator's final layer without an activation applied (it is not the loss). That's why the example follows it with

probGenerated = sigmoid(dlYPredGenerated);

Therefore, probGenerated is the output you are looking for: the probability that each generated image is real rather than fake. By the way, the output has four dimensions because it carries the dlarray format label 'SSCB' (spatial, spatial, channel, batch). The loss itself is computed afterwards by

[lossGenerator, lossDiscriminator] = ganLoss(probReal,probGenerated);
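
To reuse the trained discriminator as a classifier for new input images, you can apply the same two steps (a forward pass followed by sigmoid) and threshold the resulting probability. Below is a minimal sketch, assuming the trained dlnetDiscriminator from the example and a hypothetical image variable img that has already been resized and rescaled the same way as the training data; the 0.5 threshold is just a starting point you would tune for your defect-detection task.

% Classify a new image with the trained discriminator (sketch).
% Assumes img is preprocessed like the training images (e.g. resized to 64-by-64 RGB).
dlX = dlarray(single(img), 'SSCB');          % label dimensions as spatial-spatial-channel-batch
dlYPred = predict(dlnetDiscriminator, dlX);  % raw score (logit); use predict for inference
probReal = extractdata(sigmoid(dlYPred));    % probability that the image is "real"
if probReal > 0.5                            % simple threshold, tune as needed
    label = "Original";
else
    label = "Fake";
end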