I want to save a Keras model with a custom activation function. Saving just the weights is not an option, since I also want to save the optimizer state. If I use a built-in activation function like relu, saving works without a problem. But if I use a custom activation function, I get an error; I assume this is because the activation function itself cannot be stored. My custom activation function is:

def lrelu(x):
    return tf.maximum(x * 0.2, x)
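
For reference, this formula is a leaky ReLU with slope 0.2 for negative inputs. A minimal plain-Python sketch of the same arithmetic, for scalar inputs only and without TensorFlow (`lrelu_scalar` is a hypothetical name used just for this illustration):

```python
def lrelu_scalar(x):
    # Same formula as lrelu above, applied to a single Python float.
    # Positive inputs pass through; negative inputs are scaled by 0.2.
    return max(x * 0.2, x)


print(lrelu_scalar(3.0))   # positive input is returned unchanged
print(lrelu_scalar(-5.0))  # negative input is multiplied by 0.2
```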

If I use it with

keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="lrelu")(x)

I get the error below. Is there a workaround?

Traceback (most recent call last):
  File "train_K.py", line 191, in <module>
    model.save(model_fn)
  File "C:\ProgramData\Anaconda3\envs\py352\lib\site-packages\keras\engine\topology.py", line 2576, in save
    save_model(self, filepath, overwrite, include_optimizer)
  File "C:\ProgramData\Anaconda3\envs\py352\lib\site-packages\keras\models.py", line 111, in save_model
    'config': model.get_config()
  File "C:\ProgramData\Anaconda3\envs\py352\lib\site-packages\keras\engine\topology.py", line 2349, in get_config
    layer_config = layer.get_config()
  File "C:\ProgramData\Anaconda3\envs\py352\lib\site-packages\keras\layers\convolutional.py", line 466, in get_config
    config = super(Conv2D, self).get_config()
  File "C:\ProgramData\Anaconda3\envs\py352\lib\site-packages\keras\layers\convolutional.py", line 223, in get_config
    'activation': activations.serialize(self.activation),
  File "C:\ProgramData\Anaconda3\envs\py352\lib\site-packages\keras\activations.py", line 92, in serialize
    return activation.__name__
AttributeError: 'Activation' object has no attribute '__name__'
Custom activation functions in Keras should be defined in the standard Keras form (activation class type). To define an activation function properly, refer to: linka-sam

1 Answer

I was able to "solve" the problem by putting the custom activation function in the same file as the call to store the model. So for example the following fails with the above error message:

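# The file network_K.py contains the definition of the function lrelu(x).
from network_K import *

model.save("temp.h5")

But the following works fine:

# The file network_K.py contains the definition of the function lrelu(x).
# Import paths below are assumed for the Keras version in the traceback.
import tensorflow as tf
from keras.layers import Activation
from keras.utils.generic_utils import get_custom_objects

from network_K import *


# Redefine the custom activation in the same file that saves the model.
def lrelu(x):
    return tf.maximum(x * 0.2, x)

# Register it under the name used in activation="lrelu".
get_custom_objects().update({'lrelu': Activation(lrelu)})

model.save("temp.h5")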

While I consider this question answered (since I have a workaround), I am still curious why importing the function is not sufficient.
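
A possible hint, offered as a sketch rather than a confirmed explanation: the last frame of the traceback is `return activation.__name__`, so serialization seems to expect something that carries a `__name__`, such as a plain function. The plain-Python example below illustrates only that point. `FakeActivation` is a hypothetical stand-in for an object that wraps a function, and `serialize` mirrors just that single line from the traceback; neither is the real Keras code.

```python
def lrelu(x):
    # Plain-Python stand-in for the custom activation (scalar inputs only).
    return max(x * 0.2, x)


class FakeActivation:
    """Hypothetical wrapper object around a function; not the Keras class."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, x):
        # The wrapper is callable, so it can still be used as an activation.
        return self.fn(x)


def serialize(activation):
    # Mirrors only the last traceback frame: `return activation.__name__`.
    return activation.__name__


# A plain function carries its own name, so this succeeds.
print(serialize(lrelu))

# An instance of a class has no `__name__` attribute, so this raises.
try:
    serialize(FakeActivation(lrelu))
except AttributeError as err:
    print("AttributeError:", err)
```

If this is indeed the cause, registering the function itself, for example `get_custom_objects().update({'lrelu': lrelu})`, may be worth trying, since a plain function does carry a `__name__`. This is untested against the Keras version shown in the traceback.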