
I am using a CNN similar to AlexNet for an image-related regression task, with RMSE as the loss function. During the first epoch of training the loss takes on a huge value, but by the second epoch it drops to a meaningful value. Here is the log:

 1/51 [..............................] - ETA: 847s - loss: 104.1821 - acc: 0.2500 - root_mean_squared_error: 104.1821
 2/51 [>.............................] - ETA: 470s - loss: 5277326.0910 - acc: 0.5938 - root_mean_squared_error: 5277326.0910
 3/51 [>.............................] - ETA: 345s - loss: 3518246.7337 - acc: 0.5000 - root_mean_squared_error: 3518246.7337
 4/51 [=>............................] - ETA: 281s - loss: 2640801.3379 - acc: 0.6094 - root_mean_squared_error: 2640801.3379
 5/51 [=>............................] - ETA: 241s - loss: 2112661.3062 - acc: 0.5000 - root_mean_squared_error: 2112661.3062
 6/51 [==>...........................] - ETA: 214s - loss: 1760566.4758 - acc: 0.4375 - root_mean_squared_error: 1760566.4758
 7/51 [===>..........................] - ETA: 194s - loss: 1509067.6495 - acc: 0.4464 - root_mean_squared_error: 1509067.6495
 8/51 [===>..........................] - ETA: 178s - loss: 1320442.6319 - acc: 0.4570 - root_mean_squared_error: 1320442.6319
 9/51 [====>.........................] - ETA: 165s - loss: 1173734.9212 - acc: 0.4792 - root_mean_squared_error: 1173734.9212
10/51 [====>.........................] - ETA: 155s - loss: 1056369.3193 - acc: 0.4875 - root_mean_squared_error: 1056369.3193
11/51 [=====>........................] - ETA: 146s - loss: 960343.5998 - acc: 0.4943 - root_mean_squared_error: 960343.5998
12/51 [======>.......................] - ETA: 139s - loss: 880320.3762 - acc: 0.5052 - root_mean_squared_error: 880320.3762
13/51 [======>.......................] - ETA: 131s - loss: 812608.7112 - acc: 0.5216 - root_mean_squared_error: 812608.7112
14/51 [=======>......................] - ETA: 125s - loss: 754570.1939 - acc: 0.5402 - root_mean_squared_error: 754570.1939
15/51 [=======>......................] - ETA: 120s - loss: 704269.2443 - acc: 0.5479 - root_mean_squared_error: 704269.2443
16/51 [========>.....................] - ETA: 114s - loss: 660256.3035 - acc: 0.5508 - root_mean_squared_error: 660256.3035
17/51 [========>.....................] - ETA: 109s - loss: 621420.7248 - acc: 0.5607 - root_mean_squared_error: 621420.7248
18/51 [=========>....................] - ETA: 104s - loss: 586900.8398 - acc: 0.5712 - root_mean_squared_error: 586900.8398
19/51 [==========>...................] - ETA: 100s - loss: 556014.6719 - acc: 0.5806 - root_mean_squared_error: 556014.6719
20/51 [==========>...................] - ETA: 95s - loss: 528216.9077 - acc: 0.5875 - root_mean_squared_error: 528216.9077
21/51 [===========>..................] - ETA: 91s - loss: 503065.7743 - acc: 0.5967 - root_mean_squared_error: 503065.7743
22/51 [===========>..................] - ETA: 87s - loss: 480206.3521 - acc: 0.6094 - root_mean_squared_error: 480206.3521
23/51 [============>.................] - ETA: 83s - loss: 459331.8636 - acc: 0.6114 - root_mean_squared_error: 459331.8636
24/51 [=============>................] - ETA: 80s - loss: 440196.2991 - acc: 0.6159 - root_mean_squared_error: 440196.2991
25/51 [=============>................] - ETA: 76s - loss: 422590.8381 - acc: 0.6162 - root_mean_squared_error: 422590.8381
26/51 [==============>...............] - ETA: 73s - loss: 406339.5179 - acc: 0.6178 - root_mean_squared_error: 406339.5179
27/51 [==============>...............] - ETA: 69s - loss: 391292.6992 - acc: 0.6238 - root_mean_squared_error: 391292.6992
28/51 [===============>..............] - ETA: 66s - loss: 377319.9851 - acc: 0.6306 - root_mean_squared_error: 377319.9851
29/51 [===============>..............] - ETA: 63s - loss: 364310.7557 - acc: 0.6336 - root_mean_squared_error: 364310.7557
30/51 [================>.............] - ETA: 60s - loss: 352169.1059 - acc: 0.6385 - root_mean_squared_error: 352169.1059
31/51 [=================>............] - ETA: 57s - loss: 340810.8854 - acc: 0.6401 - root_mean_squared_error: 340810.8854
32/51 [=================>............] - ETA: 53s - loss: 330162.1334 - acc: 0.6455 - root_mean_squared_error: 330162.1334
33/51 [==================>...........] - ETA: 50s - loss: 320158.7622 - acc: 0.6553 - root_mean_squared_error: 320158.7622
34/51 [==================>...........] - ETA: 47s - loss: 310744.0080 - acc: 0.6645 - root_mean_squared_error: 310744.0080
35/51 [===================>..........] - ETA: 44s - loss: 301866.8259 - acc: 0.6714 - root_mean_squared_error: 301866.8259
36/51 [====================>.........] - ETA: 41s - loss: 293483.0129 - acc: 0.6762 - root_mean_squared_error: 293483.0129
37/51 [====================>.........] - ETA: 39s - loss: 285552.8197 - acc: 0.6757 - root_mean_squared_error: 285552.8197
38/51 [=====================>........] - ETA: 36s - loss: 278039.4488 - acc: 0.6752 - root_mean_squared_error: 278039.4488
39/51 [=====================>........] - ETA: 33s - loss: 270911.4670 - acc: 0.6795 - root_mean_squared_error: 270911.4670
40/51 [======================>.......] - ETA: 30s - loss: 264140.2391 - acc: 0.6820 - root_mean_squared_error: 264140.2391
41/51 [=======================>......] - ETA: 27s - loss: 257699.1895 - acc: 0.6852 - root_mean_squared_error: 257699.1895
42/51 [=======================>......] - ETA: 25s - loss: 251564.6846 - acc: 0.6890 - root_mean_squared_error: 251564.6846
43/51 [========================>.....] - ETA: 22s - loss: 245715.4124 - acc: 0.6933 - root_mean_squared_error: 245715.4124
44/51 [========================>.....] - ETA: 19s - loss: 240131.9916 - acc: 0.6960 - root_mean_squared_error: 240131.9916
45/51 [=========================>....] - ETA: 16s - loss: 234796.6948 - acc: 0.7007 - root_mean_squared_error: 234796.6948
46/51 [=========================>....] - ETA: 14s - loss: 229693.3717 - acc: 0.7045 - root_mean_squared_error: 229693.3717
47/51 [==========================>...] - ETA: 11s - loss: 224807.2748 - acc: 0.7055 - root_mean_squared_error: 224807.2748
48/51 [===========================>..] - ETA: 8s - loss: 220125.0731 - acc: 0.7077 - root_mean_squared_error: 220125.0731
49/51 [===========================>..] - ETA: 5s - loss: 215634.5638 - acc: 0.7117 - root_mean_squared_error: 215634.5638
50/51 [============================>.] - ETA: 3s - loss: 211323.1692 - acc: 0.7144 - root_mean_squared_error: 211323.1692
51/51 [============================>.] - ETA: 0s - loss: 207180.6328 - acc: 0.7151 - root_mean_squared_error: 207180.6328
52/51 [==============================] - 143s - loss: 203253.6237 - acc: 0.7157 - root_mean_squared_error: 203253.6237 - val_loss: 44.4203 - val_acc: 0.9878 - val_root_mean_squared_error: 44.4203
Epoch 2/128
 1/51 [..............................] - ETA: 117s - loss: 52.6087 - acc: 0.7188 - root_mean_squared_error: 52.6087

How should I understand this behavior? Here is my implementation. First, the RMSE function:

from keras import backend as K

def root_mean_squared_error(y_true, y_pred):
    # RMSE per sample; Keras averages these per-sample values over the batch
    return K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1))
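
As a quick sanity check (a minimal sketch; the toy values are made up), you can confirm that this returns one RMSE value per sample, which Keras then averages over the batch:

import numpy as np
from keras import backend as K

y_true = K.constant(np.array([[1.0], [2.0], [3.0]]))  # hypothetical targets, shape (3, 1)
y_pred = K.constant(np.array([[1.5], [2.5], [2.0]]))  # hypothetical predictions

print(K.eval(root_mean_squared_error(y_true, y_pred)))  # -> [0.5 0.5 1.0]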

Then compile the model:

model.compile(optimizer="rmsprop",
              loss=root_mean_squared_error,
              metrics=['accuracy', root_mean_squared_error])
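
As a side note, 'accuracy' is not a meaningful metric for a regression target, and since the loss already is the RMSE, the extra root_mean_squared_error metric only duplicates the loss column in the log. A minimal alternative (purely illustrative) is to track mean absolute error instead:

model.compile(optimizer="rmsprop",
              loss=root_mean_squared_error,
              metrics=["mae"])  # the RMSE already appears as the loss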

Then fit the model:

import time
from keras.preprocessing.image import ImageDataGenerator

estimator = alexmodel()
datagen = ImageDataGenerator()
datagen.fit(x_train)
start = time.time()
history = estimator.fit_generator(
    datagen.flow(x_train, y_train, batch_size=batch_size, shuffle=True),  # pass the targets, not the images, as labels
    epochs=epochs,
    steps_per_epoch=x_train.shape[0] // batch_size,  # must be an integer
    validation_data=(x_test, y_test))
end = time.time()

Can anyone tell me why this happens? Is anything potentially wrong?

If the loss is going down and the accuracy is going up... all seems fine. – Daniel Möller

1 Answer


So - it's important to normalize your data. It seems that you haven't normalized your target, and since a network is usually initialized in such a way that it produces small output values at the beginning, this is what made your loss so huge during the first epoch. I advise you to normalize your target (using e.g. StandardScaler or MinMaxScaler from scikit-learn), because the need to produce values on a large scale forces the weights of your network toward much higher absolute values, which is something you should prevent.
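
For example, a minimal sketch of scaling the target (it assumes y_train and y_test are NumPy arrays with one regression value per sample, and that scikit-learn is available):

from sklearn.preprocessing import StandardScaler

# fit the scaler on the training targets only, then apply it to both splits
target_scaler = StandardScaler()
y_train_scaled = target_scaler.fit_transform(y_train.reshape(-1, 1))
y_test_scaled = target_scaler.transform(y_test.reshape(-1, 1))

# train against y_train_scaled, then map predictions back to the original units
y_pred = target_scaler.inverse_transform(estimator.predict(x_test))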