
I am dealing with a semantic segmentation problem where the two classes I am interested in (in addition to the background) are quite unbalanced in the image pixels. I am currently using sparse categorical cross entropy as the loss, due to the way the training masks are encoded. Is there any version of it that takes class weights into account? I have not been able to find one, nor even the original source code of sparse_categorical_crossentropy. I have never explored the TensorFlow source code before, but the source link on the API page does not seem to point to a real implementation of the loss function.

Maybe this could be adapted to work for segmentation. I'm too inexperienced with both Python and Keras to do such a thing. Maybe it already works(?); however, it throws "dimension problems" when I pass it an array of class weights. - Manuel Popp

2 Answers


As far as I know, you can use the class_weight argument of model.fit with any loss function. I have used it with categorical_crossentropy and it works: it simply scales each sample's loss by the weight of its class, so I see no reason it should not work with sparse_categorical_crossentropy.
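To illustrate the mechanism (this is a hand-rolled NumPy sketch of the weighting idea, not Keras internals): the loss of each sample is multiplied by the weight assigned to its true class, so misclassifying a rare, up-weighted class costs more.

```python
import numpy as np

# Hypothetical class weights: classes 1 and 2 are rare, so up-weight them.
class_weight = {0: 1.0, 1: 5.0, 2: 5.0}

labels = np.array([0, 1, 2, 0])            # sparse integer labels
probs = np.array([[0.7, 0.2, 0.1],         # softmax outputs of some model
                  [0.1, 0.8, 0.1],
                  [0.2, 0.2, 0.6],
                  [0.9, 0.05, 0.05]])

# Sparse categorical cross entropy: -log(prob of the true class).
per_sample_ce = -np.log(probs[np.arange(len(labels)), labels])

# Weighting: scale each sample's loss by its class weight, then average.
weights = np.array([class_weight[c] for c in labels])
weighted_loss = np.mean(weights * per_sample_ce)
unweighted_loss = np.mean(per_sample_ce)
```

Since the rare classes carry higher weights, the weighted loss is larger whenever those classes are imperfectly predicted, which pushes the optimizer to pay more attention to them.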


I think this is the solution to weighting sparse_categorical_crossentropy in Keras. They use the following to add a "second mask" (containing the weight of each pixel, looked up from its class) to every element of the dataset.

def add_sample_weights(image, label):
  # The weights for each class, with the constraint that:
  #     sum(class_weights) == 1.0
  class_weights = tf.constant([2.0, 2.0, 1.0])
  class_weights = class_weights/tf.reduce_sum(class_weights)

  # Create an image of `sample_weights` by using the label at each pixel as an
  # index into the `class_weights`.
  sample_weights = tf.gather(class_weights, indices=tf.cast(label, tf.int32))

  return image, label, sample_weights


train_dataset.map(add_sample_weights).element_spec
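To see what the tf.gather lookup produces, here is a toy illustration using NumPy's equivalent, np.take (the label mask and shapes are made up for the example): each pixel's integer label indexes into the normalized class-weight vector.

```python
import numpy as np

# Same weights as in add_sample_weights above, normalized to sum to 1.
class_weights = np.array([2.0, 2.0, 1.0])
class_weights = class_weights / class_weights.sum()   # [0.4, 0.4, 0.2]

# Hypothetical 2x2 label mask with class ids 0..2.
label = np.array([[0, 1],
                  [2, 2]])

# np.take mirrors tf.gather: pick class_weights[label] element-wise,
# yielding a per-pixel weight map with the same shape as the mask.
sample_weights = np.take(class_weights, label)
# sample_weights == [[0.4, 0.4], [0.2, 0.2]]
```

Keras accepts this third tensor as per-sample (here, per-pixel) weights, so the loss at each pixel is scaled by the weight of that pixel's class.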

Then they simply use tf.keras.losses.SparseCategoricalCrossentropy as the loss function and fit like this:

weighted_model.fit(
    train_dataset.map(add_sample_weights),
    epochs=1,
    steps_per_epoch=10)