I'm trying to train multiple Keras models in a loop to evaluate different parameters. To avoid memory problems, I call K.clear_session() before building each model.
After adding the K.clear_session() call, I started getting this error when saving the second model:
        raise ValueError("Tensor %s is not an element of this graph." % obj)
    ValueError: Tensor Tensor("level1/kernel:0", shape=(3, 3, 3, 16), dtype=float32_ref) is not an element of this graph.

    During handling of the above exception, another exception occurred:

    Traceback (most recent call last):
      File "/home/gus/workspaces/wpy/cnn/srs/train_generators.py", line 286, in <module>
        train_models(model_defs)
      File "/home/gus/workspaces/wpy/cnn/srs/train_generators.py", line 196, in train_models
        model.save(file_path)
      File "/home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/keras/engine/network.py", line 1090, in save
        save_model(self, filepath, overwrite, include_optimizer)
      File "/home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/keras/engine/saving.py", line 382, in save_model
        _serialize_model(model, f, include_optimizer)
      File "/home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/keras/engine/saving.py", line 97, in _serialize_model
        weight_values = K.batch_get_value(symbolic_weights)
      File "/home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2420, in batch_get_value
        return get_session().run(ops)
      File "/home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 929, in run
        run_metadata_ptr)
      File "/home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1137, in _run
        self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
      File "/home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 471, in __init__
        self._fetch_mapper = _FetchMapper.for_fetch(fetches)
      File "/home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 261, in for_fetch
        return _ListFetchMapper(fetch)
      File "/home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 370, in __init__
        self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
      File "/home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 370, in <listcomp>
        self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
      File "/home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 271, in for_fetch
        return _ElementFetchMapper(fetches, contraction_fn)
      File "/home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 307, in __init__
        'Tensor. (%s)' % (fetch, str(e)))
    ValueError: Fetch argument cannot be interpreted as a Tensor. (Tensor Tensor("level1/kernel:0", shape=(3, 3, 3, 16), dtype=float32_ref) is not an element of this graph.)
The code is basically:

    while <models to train>:
        K.clear_session()
        model = modeldef.build()  # everything that has a tensor goes here and just here
        # create generators from directories
        opt = Adam(lr=0.001, decay=0.001 / epochs)
        model.compile(...)
        H = model.fit_generator(...)
        model.save(file_path)  # --> here it crashes
No matter how deep the network is, a super simple ConvNet like this makes the code fail when saving:
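    from keras.layers import Activation, Convolution2D, Input
    from keras.models import Model

    class SuperSimpleCNN:
        def __init__(self, img_size, depth):
            self.img_size = img_size  # height and width of the square input images
            self.depth = depth        # number of input (and output) channels

        def build(self):
            # All layers and tensors are created here, on each call
            init = Input(shape=(self.img_size, self.img_size, self.depth))
            x = Convolution2D(16, (3, 3), padding='same', name='level1')(init)
            x = Activation('relu')(x)
            out = Convolution2D(self.depth, (5, 5), padding='same', name='output')(x)
            model = Model(init, out)
            return model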
Looking at similar problems, I understand that the error comes from Keras sharing a global session, and that tensors from the graphs of different models can't be mixed.
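To make that understanding concrete, here is a toy model of how I picture it. This is pure Python, not TensorFlow or Keras code: `ToyGraph`, `ToyTensor`, and `ToyBackend` are hypothetical names I made up for illustration, and the real internals are certainly more involved. It only shows the idea that a tensor created before the global graph is replaced no longer belongs to the current graph.

```python
class ToyGraph:
    """Hypothetical stand-in for a computation graph (not a TensorFlow class)."""
    pass


class ToyTensor:
    """Hypothetical tensor that remembers the graph it was created in."""

    def __init__(self, name, graph):
        self.name = name
        self.graph = graph


class ToyBackend:
    """Hypothetical backend holding a single global graph."""

    def __init__(self):
        self.graph = ToyGraph()

    def clear_session(self):
        # Replace the global graph. Tensors created earlier still point
        # at the old graph object.
        self.graph = ToyGraph()

    def variable(self, name):
        # New tensors are always attached to the current global graph.
        return ToyTensor(name, self.graph)

    def run(self, tensor):
        # Refuse to evaluate a tensor that belongs to a different graph.
        if tensor.graph is not self.graph:
            raise ValueError(
                "Tensor %s is not an element of this graph." % tensor.name)
        return "value of " + tensor.name


toy_backend = ToyBackend()
stale = toy_backend.variable("level1/kernel:0")  # created in the first graph
toy_backend.clear_session()                      # global graph is replaced
fresh = toy_backend.variable("level1/kernel:0")  # created in the new graph

toy_backend.run(fresh)  # fine: belongs to the current graph
try:
    toy_backend.run(stale)  # raises: belongs to the discarded graph
except ValueError as err:
    print(err)
```

If this picture is right, the error would mean that something evaluated during save still refers to a tensor from a previous iteration, but in my loop I rebuild the model after every clear, so I can't see where such a reference would come from.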
But I don't understand why calling K.clear_session() before building each model makes the save operation fail from the second iteration onward. I also don't understand why one message refers to a Tensor while another refers to a Variable:

    <tf.Variable 'level1/kernel:0' shape=(3, 3, 3, 16) dtype=float32_ref> cannot be interpreted as a Tensor
Can anyone help?
Thank you.