Here I give the full formula for computing PyTorch's CrossEntropyLoss manually. There is a small precision issue, which you will see below; please post an answer if you know the exact reason.
First, understand how NLLLoss works. CrossEntropyLoss is then very similar, except that it is NLLLoss with LogSoftmax applied inside.
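As a quick check of that relationship (a minimal sketch I added for illustration; the shapes and seed are arbitrary), the two modules agree when NLLLoss is fed the log-softmax of the raw scores:

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)            # raw scores, shape (batch_size, C)
targets = torch.randint(0, 3, (4,))   # class indices in [0, C-1]

ce  = nn.CrossEntropyLoss()(logits, targets)
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), targets)
print(torch.allclose(ce, nll))        # True

With that relationship in mind, here is the setup and the manual NLLLoss: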
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def compute_nllloss_manual(x, y0):
    """
    x is a tensor of shape (batch_size, C)
    Note: the official example uses log_softmax(some tensor) as x, so it becomes CrossEntropyLoss.
    y0 is a tensor of shape (batch_size,), whose entries are integers from 0 to C-1.
    Furthermore, for C > 1 classes, the terms for the other classes are ignored (see below).
    """
    loss = 0.
    n_batch, n_class = x.shape
    for x1, y1 in zip(x, y0):
        class_index = int(y1.item())
        loss = loss + x1[class_index]  # terms for the other classes are ignored
    loss = - loss / n_batch
    return loss
We see from this formula that it is NOT the textbook negative log-likelihood, because the "other class" terms are ignored (see the comment in the code): the loss is simply loss = -(1/N) * Σ_n x[n, y0[n]]. Also, remember that PyTorch often processes things in batches. In the following code, we randomly initialize 1000 batches to verify that the formula is correct up to 15 decimal places.
torch.manual_seed(0)
precision = 15
batch_size = 10
C = 10
N_iter = 1000
n_correct_nll = 0
criterion = nn.NLLLoss()
for i in range(N_iter):
    x = torch.rand(size=(batch_size, C)).to(torch.float)
    y0 = torch.randint(0, C, size=(batch_size,))
    nll_loss = criterion(x, y0)
    manual_nll_loss = compute_nllloss_manual(x, y0)
    if i == 0:
        print('NLLLoss:')
        print('module:%s' % (str(nll_loss)))
        print('manual:%s' % (str(manual_nll_loss)))
    nll_loss_check = np.abs((nll_loss - manual_nll_loss).item()) < 10**-precision
    if nll_loss_check: n_correct_nll += 1
print('percentage NLLLoss correctly computed:%s' % (str(n_correct_nll / N_iter * 100)))
I got output like:
NLLLoss:
module:tensor(-0.4783)
manual:tensor(-0.4783)
percentage NLLLoss correctly computed:100.0
So far so good: 100% of the computations match (presumably because the manual NLLLoss involves only indexing, addition, and one division, leaving little room for rounding discrepancies). Now let us compute CrossEntropyLoss manually with the following.
def compute_crossentropyloss_manual(x, y0):
    """
    x is a tensor of shape (batch_size, C)
    y0 is a tensor of shape (batch_size,), whose entries are integers from 0 to C-1.
    """
    loss = 0.
    n_batch, n_class = x.shape
    for x1, y1 in zip(x, y0):
        class_index = int(y1.item())
        # log of the softmax probability of the target class
        loss = loss + torch.log(torch.exp(x1[class_index]) / (torch.exp(x1).sum()))
    loss = - loss / n_batch
    return loss
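As an aside, the per-sample term can be rewritten with the identity log(exp(x_c) / Σ exp(x)) = x_c - log Σ exp(x). Here is a sketch of that variant using torch.logsumexp (my addition, not part of the original code):

def compute_crossentropyloss_manual_lse(x, y0):
    """Same loss as above, written via the log-sum-exp identity."""
    loss = 0.
    n_batch, n_class = x.shape
    for x1, y1 in zip(x, y0):
        class_index = int(y1.item())
        # log(exp(x_c) / sum(exp(x))) == x_c - logsumexp(x)
        loss = loss + x1[class_index] - torch.logsumexp(x1, dim=0)
    loss = - loss / n_batch
    return loss

This is algebraically identical, but it avoids the explicit exp()-then-divide-then-log round trip, which becomes relevant for the precision question at the end. The verification below uses the original compute_crossentropyloss_manual.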
We then repeat the procedure for 1000 randomly initialized batches.
torch.manual_seed(0)
precision = 15
batch_size = 10
C = 10
N_iter = 1000
n_correct_CE = 0
criterion2 = nn.CrossEntropyLoss()
for i in range(N_iter):
    x = torch.rand(size=(batch_size, C)).to(torch.float)
    y0 = torch.randint(0, C, size=(batch_size,))
    CEloss = criterion2(x, y0)
    manual_CEloss = compute_crossentropyloss_manual(x, y0)
    if i == 0:
        print('CrossEntropyLoss:')
        print('module:%s' % (str(CEloss)))
        print('manual:%s' % (str(manual_CEloss)))
    CE_loss_check = np.abs((CEloss - manual_CEloss).item()) < 10**-precision
    if CE_loss_check: n_correct_CE += 1
print('percentage CELoss correctly computed :%s' % (str(n_correct_CE / N_iter * 100)))
The result is:
CrossEntropyLoss:
module:tensor(2.3528)
manual:tensor(2.3528)
percentage CELoss correctly computed :81.39999999999999
Only 81.4% of the computations agree to 15 decimal places. Most likely the exp() and log() calls introduce small precision differences, but I don't know exactly how.
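A plausible explanation (my own reasoning, so treat it as a hypothesis): in float32, values of order 1 have a resolution of roughly 1e-7, so a threshold of 1e-15 effectively demands bit-identical results. PyTorch computes the log-softmax directly (shifting by the row maximum for numerical stability) rather than as log(exp(x_c) / Σ exp(x)), and the two algebraically equal forms can round to different float32 values. A small demo of the discrepancy:

torch.manual_seed(0)
x1 = torch.rand(10)
c = 3
naive  = torch.log(torch.exp(x1[c]) / torch.exp(x1).sum())  # form used above
stable = x1[c] - torch.logsumexp(x1, dim=0)                 # log-sum-exp form
print('naive :%.10f' % naive.item())
print('stable:%.10f' % stable.item())
print('diff  :%.2e' % (naive - stable).item())  # typically 0 or ~1e-8

When the two forms happen to round identically, the check passes; otherwise it fails, which would explain why a fraction of batches (here about 19%) miss the 1e-15 threshold. Loosening the threshold to around 1e-6, or doing the comparison in float64, should bring the agreement back to 100%.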