I am trying to write a custom loss function in Keras in which I need to weight the MSE between y_true and y_pred (shape: (batch_size, 64, 64)) by the output of an intermediate layer (shape: (batch_size, 1)).
The operation I need is simply to weight (multiply) each batch element of the MSE by a factor, namely the corresponding batch element of weight_tensor.
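To make the intended operation concrete, here is a minimal NumPy sketch (not Keras code) of what I mean. It assumes the result should keep the shape (batch_size, 64, 64), i.e. the per-pixel squared error of each sample is scaled by that sample's weight. The function name, the array names and the batch size of 4 are only illustrative.

```python
import numpy as np


def weighted_squared_error(y_true, y_pred, weight):
    """Scale each sample's squared error by that sample's weight.

    y_true, y_pred: arrays of shape (batch_size, 64, 64)
    weight:         array of shape (batch_size, 1)
    """
    sq_err = (y_true - y_pred) ** 2  # shape (batch_size, 64, 64)
    # Reshape (batch_size, 1) -> (batch_size, 1, 1) so that broadcasting
    # multiplies every pixel of sample i by weight[i].
    return sq_err * weight[:, :, np.newaxis]


# Illustrative data: a batch of 4 samples with hand-picked weights.
rng = np.random.default_rng(0)
y_true = rng.normal(size=(4, 64, 64))
y_pred = rng.normal(size=(4, 64, 64))
weight = np.array([[0.0], [1.0], [2.0], [0.5]])

out = weighted_squared_error(y_true, y_pred, weight)
print(out.shape)
```

With these weights, the first sample's error is zeroed out and the third sample's error is doubled, which is the per-batch-element scaling I am after.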
I tried the following:
def loss_aux_wrapper(weight_tensor):
    def loss_aux(y_true, y_pred):
        K.print_tensor(weight_tensor, message='weight = ')
        _shape = K.shape(y_true)
        return K.reshape(K.batch_dot(K.batch_flatten(mse(y_true, y_pred)), weight_tensor, axes=[1, 1]), _shape)
    return loss_aux
but I get:
tensorflow.python.framework.errors_impl.InvalidArgumentError: In[0] mismatch In[1] shape: 4096 vs. 1: [32,1,4096] [32,1,1] 0 0 [[node loss/aux_motor_output_loss/MatMul (defined at /code/icub_sensory_enhancement/scripts/models.py:327) ]] [Op:__inference_keras_scratch_graph_6548]
K.print_tensor does not print anything either; I believe this is because it is only called once, when the loss is built at compile time?
Any suggestions are greatly appreciated!