3
votes

Given batched RGB images as input, shape=(batch_size, width, height, 3)

And a multiclass target represented as one-hot, shape=(batch_size, width, height, n_classes)

And a model (U-Net, DeepLab) with a softmax activation in the last layer.

I'm looking for a weighted categorical cross-entropy loss function in Keras/TensorFlow.

The class_weight argument in fit_generator doesn't seem to work, and I didn't find the answer here or in https://github.com/keras-team/keras/issues/2115.

def weighted_categorical_crossentropy(weights):
    # weights = [0.9,0.05,0.04,0.01]
    def wcce(y_true, y_pred):
        # y_true, y_pred shape is (batch_size, width, height, n_classes)
        loss = ?...
        return loss

    return wcce
3
By multiclass target do you mean more than 1 possible outcomes are considered? - SajanGohil
What do you mean by "outcome"? Multiclass = different pixel values indicate different classes, and you can have more than 2 classes (2 classes = binary classification). - Mendi Barel
Multilabel classification is a different kind of classification problem, where more than 1 class can be true; I confused the two. - SajanGohil

3 Answers

3
votes

I'll answer my own question:

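from keras import backend as K

def weighted_categorical_crossentropy(weights):
    # weights: one weight per class, e.g. [0.9, 0.05, 0.04, 0.01]
    def wcce(y_true, y_pred):
        # y_true, y_pred shape is (batch_size, width, height, n_classes)
        Kweights = K.constant(weights)
        if not K.is_tensor(y_pred):
            y_pred = K.constant(y_pred)
        y_true = K.cast(y_true, y_pred.dtype)
        # Per-pixel cross-entropy scaled by the weight of each pixel's true class
        return K.categorical_crossentropy(y_true, y_pred) * K.sum(y_true * Kweights, axis=-1)
    return wcce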

Usage:

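import keras

weights = [0.9, 0.05, 0.04, 0.01]  # one weight per class
loss = weighted_categorical_crossentropy(weights)
optimizer = keras.optimizers.Adam(lr=0.01)  # newer Keras versions name this argument learning_rate
model.compile(optimizer=optimizer, loss=loss)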
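
To see what wcce computes for each pixel, here is a minimal NumPy sketch of the same arithmetic. It assumes y_pred already holds softmax probabilities; the name wcce_numpy and the eps clip value are illustrative choices, and the sketch skips any renormalization or clipping details the Keras backend may apply.

```python
import numpy as np

def wcce_numpy(y_true, y_pred, weights, eps=1e-7):
    """Per-pixel weighted categorical cross-entropy (illustrative reference only)."""
    weights = np.asarray(weights, dtype=np.float64)
    y_true = np.asarray(y_true, dtype=np.float64)
    # Clip probabilities to avoid log(0); eps is an assumed value.
    y_pred = np.clip(np.asarray(y_pred, dtype=np.float64), eps, 1.0 - eps)
    # Plain cross-entropy per pixel, shape (batch_size, width, height).
    cce = -np.sum(y_true * np.log(y_pred), axis=-1)
    # Weight of each pixel's true class, shape (batch_size, width, height).
    pixel_weight = np.sum(y_true * weights, axis=-1)
    return cce * pixel_weight

# Tiny example: 1 image of 1x2 pixels with 2 classes.
weights = [0.9, 0.1]
y_true = np.array([[[[1, 0], [0, 1]]]])
y_pred = np.array([[[[0.5, 0.5], [0.5, 0.5]]]])
loss = wcce_numpy(y_true, y_pred, weights)
```

Both pixels have the same unweighted cross-entropy here, so the resulting per-pixel losses differ only by the weight of the true class (0.9 versus 0.1).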
1
votes

I'm using the Generalized Dice Loss, which works better than the weighted categorical cross-entropy in my case. My implementation is in PyTorch, but it should be fairly easy to translate.

import torch
import torch.nn as nn

class GeneralizedDiceLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, inp, targ):
        # Move the class axis last (the permute assumes channels-first input)
        inp = inp.contiguous().permute(0, 2, 3, 1)
        targ = targ.contiguous().permute(0, 2, 3, 1)

        # Per-class weight: inverse of the squared number of target pixels
        w = 1. / (torch.sum(targ, (0, 1, 2))**2 + 1e-9)

        # Weighted overlap between target and prediction, summed over classes
        numerator = targ * inp
        numerator = w * torch.sum(numerator, (0, 1, 2))
        numerator = torch.sum(numerator)

        # Weighted total mass of target and prediction, summed over classes
        denominator = targ + inp
        denominator = w * torch.sum(denominator, (0, 1, 2))
        denominator = torch.sum(denominator)

        dice = 2. * (numerator + 1e-9) / (denominator + 1e-9)

        return 1. - dice
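
As a starting point for translating this to the question's channels-last layout, here is a rough NumPy sketch of the same arithmetic. It assumes inputs of shape (batch_size, width, height, n_classes), so the permute step is dropped; the function name and eps default are illustrative. For a Keras loss, the np.sum calls would need to be swapped for backend operations such as K.sum.

```python
import numpy as np

def generalized_dice_loss(y_true, y_pred, eps=1e-9):
    """Generalized Dice Loss for channels-last arrays (illustrative reference only)."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    reduce_axes = (0, 1, 2)  # batch and both spatial axes; classes stay separate
    # Per-class weight: inverse of the squared number of target pixels.
    w = 1.0 / (y_true.sum(axis=reduce_axes) ** 2 + eps)
    # Weighted overlap and weighted total mass, each summed over classes.
    numerator = np.sum(w * np.sum(y_true * y_pred, axis=reduce_axes))
    denominator = np.sum(w * np.sum(y_true + y_pred, axis=reduce_axes))
    dice = 2.0 * (numerator + eps) / (denominator + eps)
    return 1.0 - dice

# Tiny example: 1 image of 2x2 pixels with 2 classes, both classes present.
y_true = np.array([[[[1, 0], [0, 1]],
                    [[1, 0], [0, 1]]]])
perfect = generalized_dice_loss(y_true, y_true)      # prediction equals target
wrong = generalized_dice_loss(y_true, 1 - y_true)    # every pixel misclassified
```

A perfect prediction gives a loss close to 0 and a completely wrong one gives a loss close to 1. Note that a class with no target pixels receives a very large weight (about 1/eps), as in the PyTorch version.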
0
votes

This issue might be similar to "Unbalanced data and weighted cross entropy", which has an accepted answer.