
I am trying to write a custom loss function in Keras where I need to weight the MSE between y_true and y_pred (shape: (batch_size, 64, 64)) by the output of an intermediate layer (shape: (batch_size, 1)).

The operation I need is simply to weight (multiply) each batch element of the MSE by a factor, i.e. the corresponding batch element of weight_tensor.

I tried the following:

def loss_aux_wrapper(weight_tensor):
    def loss_aux(y_true, y_pred):
        K.print_tensor(weight_tensor, message='weight = ')
        _shape = K.shape(y_true)
        return K.reshape(K.batch_dot(K.batch_flatten(mse(y_true, y_pred)), weight_tensor, axes=[1,1]),_shape)
    return loss_aux

but I get:

tensorflow.python.framework.errors_impl.InvalidArgumentError:  In[0] mismatch In[1] shape: 4096 vs. 1: [32,1,4096] [32,1,1] 0 0      [[node loss/aux_motor_output_loss/MatMul (defined at /code/icub_sensory_enhancement/scripts/models.py:327) ]] [Op:__inference_keras_scratch_graph_6548]
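The shapes in the message point to the cause: the flattened loss has 4096 (64 × 64) values per sample, and batch_dot with axes=[1,1] ends up as a batched matrix product [32,1,4096] x [32,1,1] whose inner dimensions (4096 vs. 1) do not match. Scaling each sample by a factor is an elementwise, broadcast multiplication rather than a dot product. The NumPy sketch below only illustrates that shape arithmetic; it uses np.matmul to mimic the batched product and is not the actual Keras op.

```python
import numpy as np

batch_size = 32
# Stand-in for the flattened per-element loss: (32, 4096)
loss = np.ones((batch_size, 64 * 64))
# One factor per sample, shape (32, 1); sample i gets weight i
weight = np.arange(batch_size, dtype=float).reshape(batch_size, 1)

# Mimic the batched product from the error: (32, 1, 4096) @ (32, 1, 1)
try:
    np.matmul(loss[:, np.newaxis, :], weight[:, :, np.newaxis])
    contraction_failed = False
except ValueError:
    # Inner dimensions 4096 and 1 do not match
    contraction_failed = True

# Broadcasting (32, 1) against (32, 4096) scales every value of sample i by weight[i]
scaled = loss * weight
print(contraction_failed, scaled.shape)  # True (32, 4096)
```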

K.print_tensor does not output anything; is that because it is called at compile time?

Any suggestion is greatly appreciated!

1 Answer

To weight an MSE loss, you can pass the sample_weight= argument when calling a tf.keras.losses.MeanSquaredError loss object (the plain mse / mean_squared_error function does not accept this argument). As per the docs,

If sample_weight is a tensor of size [batch_size], then the total loss for each sample of the batch is rescaled by the corresponding element in the sample_weight vector. If the shape of sample_weight is [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of y_pred is scaled by the corresponding value of sample_weight. (Note on dN-1: all loss functions reduce by 1 dimension, usually axis=-1.)
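A rough NumPy model of that documented behaviour may help. It assumes the tf.keras default reduction (SUM_OVER_BATCH_SIZE: the weighted losses are summed and divided by their number of elements), and keras_like_weighted_mse is a hypothetical helper, not a Keras function. It shows how a (batch_size,) weight lines up with the (batch_size, 64) loss that MSE leaves after reducing the last axis, and that a (batch_size, 1) weight broadcasts to the same result.

```python
import numpy as np

def keras_like_weighted_mse(y_true, y_pred, sample_weight):
    """Hypothetical, rough NumPy model of a sample-weighted MeanSquaredError."""
    # MSE reduces the last axis: (batch_size, 64, 64) -> (batch_size, 64)
    losses = np.mean(np.square(y_true - y_pred), axis=-1)
    sample_weight = np.asarray(sample_weight, dtype=float)
    # A (batch_size,) weight is one rank lower than the losses: add a trailing
    # axis so it rescales every loss element of its sample
    if sample_weight.ndim == losses.ndim - 1:
        sample_weight = sample_weight[..., np.newaxis]
    weighted = losses * sample_weight
    # Assumed SUM_OVER_BATCH_SIZE reduction: sum, then divide by the element count
    return weighted.sum() / losses.size

rng = np.random.default_rng(1)
y_true = rng.normal(size=(4, 64, 64))
y_pred = rng.normal(size=(4, 64, 64))
w_flat = np.array([0.0, 1.0, 2.0, 0.5])   # shape (batch_size,)
w_col = w_flat.reshape(-1, 1)              # shape (batch_size, 1)
# Both weight shapes give the same loss
print(np.isclose(keras_like_weighted_mse(y_true, y_pred, w_flat),
                 keras_like_weighted_mse(y_true, y_pred, w_col)))  # True
```

Under this assumed reduction, a weight of 0 removes a sample's contribution, while the sum is still divided by the full element count.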

In your case, weight_tensor has shape (batch_size, 1), so first reshape it into a vector of shape (batch_size,):

reshaped_weight_tensor = K.reshape(weight_tensor, (-1,))  # -1 infers the batch dimension
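Two details matter for this reshape. In Python, (batch_size) is just the integer batch_size in parentheses, not a one-element tuple; that needs a trailing comma, (batch_size,). Also, batch_size is not defined inside the loss function, so passing -1 lets the size of that dimension be inferred (np.reshape and K.reshape, which wraps tf.reshape, both accept -1). A small NumPy illustration, with a random array standing in for the intermediate layer's output:

```python
import numpy as np

batch_size = 32
# Stand-in for the intermediate layer's output, shape (batch_size, 1)
weight_tensor = np.random.default_rng(2).uniform(size=(batch_size, 1))

# Parentheses alone do not create a tuple; the trailing comma does
print(type((batch_size)).__name__, type((batch_size,)).__name__)  # int tuple

# -1 infers the batch dimension, so no batch_size variable is needed
reshaped = np.reshape(weight_tensor, (-1,))
print(reshaped.shape)  # (32,)
```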

This tensor can then be passed as sample_weight to a MeanSquaredError loss object:

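import tensorflow as tf
from tensorflow.keras import backend as K

# Loss object: unlike the plain mse function, its __call__ accepts sample_weight
mse = tf.keras.losses.MeanSquaredError()

def loss_aux_wrapper(weight_tensor):
    def loss_aux(y_true, y_pred):
        # (batch_size, 1) -> (batch_size,): one weight per sample
        reshaped_weight_tensor = K.reshape(weight_tensor, (-1,))
        return mse(y_true, y_pred, sample_weight=reshaped_weight_tensor)
    return loss_aux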