2 votes

I am trying to build two neural networks for classification: one for binary and the second for multi-class classification. I am trying to use torch.nn.CrossEntropyLoss() as the loss function, but when I try to train my first neural network I get the following error:

multi-target not supported at /opt/conda/conda-bld/pytorch_1565272271120/work/aten/src/THNN/generic/ClassNLLCriterion.c:22

From my analysis, I found that my dataset has two problems that caused the error.

  • My dataset is one-hot encoded; I used one-hot encoding to preprocess it. The first target variable Y_binary has the shape torch.Size([125973, 1]) and is full of 0s and 1s indicating the classes 'No' and 'Yes'.
  • My data has the wrong dimensions? I found that I can't use a simple vector with the cross-entropy loss function. Some people used the following code to reshape their output tensor before feeding it to the loss function.

# flatten a (N, C, H, W) output to (N*H*W, C) so each row is one prediction over the classes
out = out.permute(0, 2, 3, 1).contiguous().view(-1, class_number)

I didn't really understand the reasoning behind this code (see the sketch after the dimensions below), but it seems to me that I need to keep track of the following variables: Class_Number, Batch_size, Dimension_Output. For my code, here are the dimensions:

X_train.shape: (125973, 122)
Y_train2.shape: (125973, 1)
batch_size = 64
K = len(set(Y_train2))  # binary classification; for multi-class use K = len(set(Y_train5))
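As far as I can tell, the permute/view line is meant for outputs that have extra spatial dimensions; here is a minimal sketch (with hypothetical shapes and names) of what it does, assuming a segmentation-style output of shape (N, C, H, W):

import torch

N, C, H, W = 64, 5, 8, 8                 # hypothetical batch, class count, spatial dims
out = torch.randn(N, C, H, W)            # raw model output (logits) with spatial dimensions

# move the class dimension last, then flatten everything else,
# so each row is one prediction over C classes
flat = out.permute(0, 2, 3, 1).contiguous().view(-1, C)
print(flat.shape)                        # torch.Size([4096, 5])

A fully connected network already produces a (batch_size, K) output, so this reshape may not apply to my case at all.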
  • Should the target value be one-hot encoded? If not, how can I feed a nominal feature to the loss function?
  • If I need to reshape the output, can you help me do this for my code?

I am trying to use this loss function for both my neural networks.

Thank you in advance,


1 Answer

1 vote

The error is due to the usage of torch.nn.CrossEntropyLoss(), which is used when you want to predict exactly 1 class out of N classes: it expects the target to be a 1-D tensor of class indices, not one-hot rows, which is why a target of shape (batch, 1) triggers the "multi-target not supported" error. For multi-label classification, you should instead use torch.nn.BCEWithLogitsLoss(), which combines a Sigmoid layer and the BCELoss in one single class.
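For your binary case you can keep CrossEntropyLoss; you just need to squeeze the (batch, 1) target column into a 1-D tensor of class indices. A minimal sketch (tensor names and sizes are hypothetical):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

logits = torch.randn(64, 2)              # (batch_size, K) raw model outputs
y = torch.randint(0, 2, (64, 1))         # (batch_size, 1) column of 0/1 labels

# CrossEntropyLoss expects a 1-D LongTensor of class indices, shape (batch_size,)
loss = criterion(logits, y.squeeze(1).long())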

In the case of multi-label classification (where more than one class can be present at once), if you use Sigmoid + BCELoss (or BCEWithLogitsLoss on the raw logits), then you need the target to be one-hot/multi-hot encoded, i.e. something like this per sample: [0 1 0 0 0 1 0 0 1 0], where a 1 appears at the location of each class that is present.
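A minimal sketch of that multi-label setup (shapes are hypothetical):

import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()

logits = torch.randn(64, 10)                     # (batch_size, num_classes) raw outputs
targets = torch.randint(0, 2, (64, 10)).float()  # multi-hot rows like [0 1 0 0 0 1 0 0 1 0]

# BCEWithLogitsLoss applies the sigmoid internally, so pass raw logits
loss = criterion(logits, targets)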