I'm in the process of fine-tuning a BERT model on the long answer task of the Natural Questions dataset. I'm training the model just like a SQuAD model (predicting start and end tokens).
I use Huggingface and PyTorch.
So the targets and labels have a shape/size of [batch, 2]. My problem is that I can't pass in "multi-targets", which I think refers to the fact that the last dimension is 2.
RuntimeError: multi-target not supported at /pytorch/aten/src/THCUNN/generic/ClassNLLCriterion.cu:18
Should I choose another loss function or is there another way to bypass this problem?
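From what I understand, nn.CrossEntropyLoss expects the logits to have shape [batch, num_classes] and the targets to be a 1-D tensor of class indices with shape [batch]; a minimal sketch of what I mean, with made-up shapes:

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

logits = torch.randn(3, 5)         # [batch=3, num_classes=5]
targets = torch.tensor([1, 0, 4])  # [batch=3], one class index per example
loss = criterion(logits, targets)  # works

bad_targets = torch.tensor([[1, 3], [0, 2], [4, 4]])  # [3, 2]
# criterion(logits, bad_targets)   # fails: the targets can't have a second dimension
#                                  # (the exact error message depends on the PyTorch version)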
This is the code I'm using:
import torch
import torch.nn as nn

def loss_fn(preds, labels):
    return nn.CrossEntropyLoss()(preds, labels)
class DecoderModel(nn.Module):
    def __init__(self, model_args, encoder_config, loss_fn):
        super(DecoderModel, self).__init__()
        # ...

    def forward(self, pooled_output, labels):
        pooled_output = self.dropout(pooled_output)
        logits = self.linear(pooled_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logit = torch.squeeze(start_logits, dim=-1)
        end_logit = torch.squeeze(end_logits, dim=-1)
        # Concatenate into a "label"
        preds = torch.cat((start_logits, end_logits), -1)
        # Calculate loss
        loss = self.loss_fn(preds=preds, labels=labels)
        return loss, preds
The targets' properties are: dtype torch.int64, shape [3, 2]
The predictions' properties are: dtype torch.float32, shape [3, 2]
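If it helps, a standalone snippet with the same shapes and dtypes should trigger the same error (the exact message may differ across PyTorch versions):

import torch
import torch.nn as nn

preds = torch.randn(3, 2)                     # torch.float32, [3, 2]
targets = torch.randint(0, 2, (3, 2)).long()  # torch.int64, [3, 2]

# nn.CrossEntropyLoss()(preds, targets)
# -> RuntimeError: multi-target not supported ...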
SOLVED - this is my solution
def loss_fn(preds: list, labels):
    start_token_labels, end_token_labels = labels.split(1, dim=-1)
    start_token_labels = start_token_labels.squeeze(-1)
    end_token_labels = end_token_labels.squeeze(-1)

    # Debug prints: preds[0] and preds[1] have the same shape and dtype,
    # and the two label tensors do as well.
    print('*' * 50)
    print(preds[0].shape)
    print(preds[0].dtype)
    print(start_token_labels.shape)
    print(start_token_labels.dtype)

    start_loss = nn.CrossEntropyLoss()(preds[0], start_token_labels)
    end_loss = nn.CrossEntropyLoss()(preds[1], end_token_labels)
    avg_loss = (start_loss + end_loss) / 2
    return avg_loss
Basically I'm splitting the logits (rather than concatenating them) and splitting the labels, computing the cross-entropy loss on the start and end pairs separately, and finally taking the average of the two losses. Hope this gives you an idea for solving your own problem!
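For anyone adapting this: a sketch of how the new loss_fn can be called, assuming forward now passes the two logit tensors as a list instead of concatenating them (the shapes below are purely illustrative):

import torch

# Illustrative shapes: batch of 3, sequence length of 384.
start_logits = torch.randn(3, 384)
end_logits = torch.randn(3, 384)
labels = torch.randint(0, 384, (3, 2))  # [batch, 2]: start and end token index per example

loss = loss_fn([start_logits, end_logits], labels)  # loss_fn as defined above
print(loss)  # single scalar: the average of the start and end cross-entropy losses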