I am working on a semantic segmentation problem where the two classes I am interested in (in addition to background) are quite unbalanced in terms of image pixels. I am currently using sparse categorical cross-entropy as the loss because of the way the training masks are encoded. Is there a version of it that takes class weights into account? I have not been able to find one, nor even the original source code of sparse_categorical_crossentropy. I have never explored the TF source code before, but the source link on the API page does not seem to point to an actual implementation of the loss function.
2 Answers
I think this is the way to weight sparse_categorical_crossentropy in Keras.
They add a "second mask" to the dataset, containing the class weight for each pixel of the mask image, so that every element becomes an (image, label, sample_weights) triple.
import tensorflow as tf

def add_sample_weights(image, label):
    # The weights for each class (one entry per class id), with the constraint that:
    #     sum(class_weights) == 1.0
    class_weights = tf.constant([2.0, 2.0, 1.0])
    class_weights = class_weights / tf.reduce_sum(class_weights)

    # Create an image of `sample_weights` by using the label at each pixel as an
    # index into the `class_weights`.
    sample_weights = tf.gather(class_weights, indices=tf.cast(label, tf.int32))

    return image, label, sample_weights

# Inspect the structure of the mapped dataset: (image, label, sample_weights).
train_dataset.map(add_sample_weights).element_spec
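To make the gather step concrete, here is a minimal NumPy sketch of the same lookup. The 2x3 `label` mask is a made-up example, and integer-array indexing is used here as a stand-in for `tf.gather`; it is only meant to illustrate the idea, not to replace the TensorFlow code above.

```python
import numpy as np

# Same class weights as in the snippet above, normalised to sum to 1.
class_weights = np.array([2.0, 2.0, 1.0])
class_weights = class_weights / class_weights.sum()

# Hypothetical 2x3 label mask with class ids 0, 1 and 2.
label = np.array([[0, 1, 2],
                  [2, 2, 0]])

# Indexing with the integer mask looks up one weight per pixel,
# which is what tf.gather does in the TensorFlow version.
sample_weights = class_weights[label]

print(sample_weights.shape)
print(sample_weights)
```

The result has the same shape as the mask, with each pixel holding the weight of its class.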
Then they just use tf.keras.losses.SparseCategoricalCrossentropy as the loss function and fit like this:
weighted_model.fit(
    train_dataset.map(add_sample_weights),
    epochs=1,
    steps_per_epoch=10)
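To see what the per-pixel weights do to the loss, here is a rough NumPy sketch of a class-weighted sparse categorical cross-entropy. The function name `weighted_sparse_cce` and the toy inputs are hypothetical, and the final reduction (a plain mean over all pixels of loss times weight) is an assumption made for illustration; Keras's exact reduction may differ depending on the version and the loss's `reduction` setting.

```python
import numpy as np

def weighted_sparse_cce(labels, probs, class_weights):
    """Illustrative only. labels: int array (H, W); probs: float array (H, W, C)."""
    # Probability the model assigns to the true class at each pixel.
    true_probs = np.take_along_axis(probs, labels[..., None], axis=-1)[..., 0]
    # Unweighted per-pixel cross-entropy (clipped to avoid log(0)).
    per_pixel_loss = -np.log(np.clip(true_probs, 1e-7, 1.0))
    # One weight per pixel, looked up from the class id.
    sample_weights = class_weights[labels]
    # Assumed reduction: mean of the weighted per-pixel losses.
    return float(np.mean(per_pixel_loss * sample_weights))

# Toy inputs: a 2x2 mask and uniform predictions over 3 classes.
labels = np.array([[0, 1],
                   [2, 2]])
probs = np.full((2, 2, 3), 1.0 / 3.0)
class_weights = np.array([0.4, 0.4, 0.2])

loss = weighted_sparse_cce(labels, probs, class_weights)
print(loss)
```

Pixels of a class with a larger weight contribute proportionally more to the total, which is how the rarer classes can be emphasised during training.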