43 votes

I'm a bit confused by the cross entropy loss in PyTorch.

Considering this example:

import torch
import torch.nn as nn
from torch.autograd import Variable

output = Variable(torch.FloatTensor([0,0,0,1])).view(1, -1)
target = Variable(torch.LongTensor([3]))

criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)
print(loss)

I would expect the loss to be 0. But I get:

Variable containing:
 0.7437
[torch.FloatTensor of size 1]

As far as I know cross entropy can be calculated like this:

H(p, q) = -\sum_x p(x) \log(q(x))

But shouldn't the result then be 1 * log(1) = 0?

I tried different inputs, like one-hot encodings, but that doesn't work at all, so it seems the input shape of the loss function is okay.

I would be really grateful if someone could help me out and tell me where my mistake is.

Thanks in advance!


5 Answers

75 votes

In your example you are treating the output [0, 0, 0, 1] as probabilities, as required by the mathematical definition of cross entropy. But PyTorch treats them as raw scores that don't need to sum to 1; they first need to be converted into probabilities, for which it uses the softmax function.

So H(p, q) becomes:

H(p, softmax(output))

Translating the output [0, 0, 0, 1] into probabilities:

softmax([0, 0, 0, 1]) = [0.1749, 0.1749, 0.1749, 0.4754]

whence:

-log(0.4754) = 0.7437
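
To check this numerically in current PyTorch, here is a minimal sketch (plain tensors instead of the now-deprecated Variable; the values are the ones from the question):

import torch
import torch.nn.functional as F

output = torch.tensor([[0., 0., 0., 1.]])   # shape (1, 4), raw scores
target = torch.tensor([3])

probs = F.softmax(output, dim=1)            # tensor([[0.1749, 0.1749, 0.1749, 0.4754]])
manual = -torch.log(probs[0, target[0]])    # tensor(0.7437)
builtin = F.cross_entropy(output, target)   # tensor(0.7437)
print(probs, manual, builtin)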

25 votes

Your understanding is correct, but PyTorch doesn't compute cross entropy that way. PyTorch uses the following formula:

loss(x, class) = -log(exp(x[class]) / (\sum_j exp(x[j])))
               = -x[class] + log(\sum_j exp(x[j]))

Since, in your scenario, x = [0, 0, 0, 1] and class = 3, if you evaluate the above expression, you would get:

loss(x, class) = -1 + log(exp(0) + exp(0) + exp(0) + exp(1))
               = 0.7437

PyTorch uses the natural logarithm.
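
A minimal sketch that evaluates this formula directly, with the same x and class as above:

import torch

x = torch.tensor([0., 0., 0., 1.])
cls = 3

loss = -x[cls] + torch.log(torch.exp(x).sum())
print(loss)   # tensor(0.7437), matching nn.CrossEntropyLoss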

11 votes

I would like to add an important note, as this often leads to confusion.

Softmax is not a loss function, nor is it really an activation function. It has a very specific task: it is used in multi-class classification to normalize the scores for the given classes. By doing so we get probabilities for each class that sum to 1.

Softmax is combined with Cross-Entropy-Loss to calculate the loss of a model.

Unfortunately, because this combination is so common, it is often abbreviated. Some use the term Softmax-Loss, whereas PyTorch simply calls it Cross-Entropy-Loss.
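
For illustration, a minimal sketch of the normalization softmax performs (the scores here are arbitrary):

import torch
import torch.nn.functional as F

scores = torch.tensor([2.0, 1.0, 0.1])   # arbitrary class scores
probs = F.softmax(scores, dim=0)

print(probs)        # tensor([0.6590, 0.2424, 0.0986])
print(probs.sum())  # tensor(1.), the probabilities sum to 1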

1 vote

The combination of nn.LogSoftmax and nn.NLLLoss is equivalent to using nn.CrossEntropyLoss. This terminology is a particularity of PyTorch, as nn.NLLLoss in fact computes the cross entropy but with log probability predictions as inputs, whereas nn.CrossEntropyLoss takes scores (sometimes called logits). Technically, nn.NLLLoss is the cross entropy between the Dirac distribution, putting all mass on the target, and the predicted distribution given by the log probability inputs.

PyTorch's CrossEntropyLoss expects unbounded scores (interpretable as logits / log-odds) as input, not probabilities (as the CE is traditionally defined).
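
A short sketch of that equivalence, using random scores just for illustration:

import torch
import torch.nn as nn

torch.manual_seed(0)
scores = torch.randn(4, 5)             # batch of 4 samples, 5 classes
target = torch.randint(0, 5, (4,))

ce = nn.CrossEntropyLoss()(scores, target)
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(scores), target)

print(ce, nll, torch.allclose(ce, nll))   # the two losses agree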

1 vote

Here I give the full formula to manually compute PyTorch's CrossEntropyLoss. There is a small precision problem you will see later; do post an answer if you know the exact reason.

First, understand how NLLLoss works. Then CrossEntropyLoss is very similar, except it is NLLLoss with Softmax inside.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def compute_nllloss_manual(x, y0):
    """
    x is a tensor with shape (batch_size, C).
    Note: the official example uses log_softmax(some vector) as x, which turns this into CrossEntropyLoss.
    y0 has shape (batch_size,); its entries are integers from 0 to C-1.
    For C > 1 classes, the terms for the other classes are ignored (see the comment below).
    """
    loss = 0.
    n_batch, n_class = x.shape
    for x1, y1 in zip(x, y0):
        class_index = int(y1.item())
        loss = loss + x1[class_index]  # terms for the other classes are ignored
    loss = -loss / n_batch
    return loss

We see from the formula that it is NOT like the standard prescribed NLLLoss, because the "other class" terms are ignored (see the comment in the code). Also, remember that PyTorch often processes data in batches. In the following code, we randomly initialize 1000 batches to verify that the formula is correct up to 15 decimal places.

torch.manual_seed(0)
precision = 15

batch_size=10
C = 10

N_iter = 1000
n_correct_nll = 0

criterion = nn.NLLLoss()
for i in range(N_iter):
    x = torch.rand(size=(batch_size,C)).to(torch.float)
    y0 = torch.randint(0,C,size=(batch_size,))

    nll_loss = criterion(x,y0)
    manual_nll_loss = compute_nllloss_manual(x,y0)
    if i==0:
        print('NLLLoss:')
        print('module:%s'%(str(nll_loss)))
        print('manual:%s'%(str(manual_nll_loss)))

    nll_loss_check = np.abs((nll_loss - manual_nll_loss).item()) < 10**-precision
    if nll_loss_check: n_correct_nll+=1

print('percentage NLLLoss correctly computed:%s'%(str(n_correct_nll/N_iter*100)))

I got output like:

NLLLoss:
module:tensor(-0.4783)
manual:tensor(-0.4783)
percentage NLLLoss correctly computed:100.0

So far so good, 100% of the computations are correct. Now let us compute CrossEntropyLoss manually with the following.

def compute_crossentropyloss_manual(x, y0):
    """
    x is a tensor with shape (batch_size, C).
    y0 has shape (batch_size,); its entries are integers from 0 to C-1.
    """
    loss = 0.
    n_batch, n_class = x.shape
    for x1, y1 in zip(x, y0):
        class_index = int(y1.item())
        loss = loss + torch.log(torch.exp(x1[class_index]) / (torch.exp(x1).sum()))
    loss = -loss / n_batch
    return loss

And then repeat the procedure for 1000 randomly initialized batches.

torch.manual_seed(0)
precision = 15

batch_size=10
C = 10

N_iter = 1000
n_correct_CE = 0

criterion2 = nn.CrossEntropyLoss()
for i in range(N_iter):
    x = torch.rand(size=(batch_size,C)).to(torch.float)
    y0 = torch.randint(0,C,size=(batch_size,))

    CEloss = criterion2(x,y0)
    manual_CEloss = compute_crossentropyloss_manual(x,y0)
    if i==0:
        print('CrossEntropyLoss:')
        print('module:%s'%(str(CEloss)))
        print('manual:%s'%(str(manual_CEloss)))

    CE_loss_check = np.abs((CEloss - manual_CEloss).item()) < 10**-precision
    if CE_loss_check: n_correct_CE+=1

print('percentage CELoss correctly computed :%s'%(str(n_correct_CE/N_iter*100)))

The result is:

CrossEntropyLoss:
module:tensor(2.3528)
manual:tensor(2.3528)
percentage CELoss correctly computed :81.39999999999999

I got 81.4% of the computations correct up to 15 decimal places. Most likely the exp() and the log() introduce small precision problems, but I don't know exactly how.
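
One plausible explanation (an assumption on my part, not a statement about PyTorch's exact implementation) is that PyTorch evaluates log_softmax in the log-sum-exp form, i.e. x[class] - logsumexp(x), rather than log(exp(x[class]) / sum(exp(x))), so the floating point rounding happens in a different order. Rewriting the manual loss that way may close the gap:

def compute_crossentropyloss_manual_v2(x, y0):
    """
    Same as compute_crossentropyloss_manual, but written as
    x[class] - logsumexp(x), the usual log-softmax formulation.
    """
    loss = 0.
    n_batch, n_class = x.shape
    for x1, y1 in zip(x, y0):
        class_index = int(y1.item())
        loss = loss + (x1[class_index] - torch.logsumexp(x1, dim=0))
    loss = -loss / n_batch
    return loss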