I am looking for a way to call the Keras Model.predict() function from within a subprocess.
I am using Keras 2.3.1 with TensorFlow 2.0.0 (I also tried Keras 2.2.5 with TensorFlow 1.14).
The following code throws the error shown below:
import itertools
import random
from abc import ABC
from multiprocessing import Pool as Pool
import numpy as np
from keras.engine.saving import load_model
from keras.models import Sequential
from keras.layers import Dense, Activation
# Base class (derived from abc.ABC) that Prediction subclasses.
# NOTE(review): the body of this class is not visible in this paste; as shown,
# a class statement followed directly by another class statement is a
# SyntaxError (it needs at least `pass`) — confirm against the original script.
class Pre(ABC):
# Wraps a small Keras model. __init__ builds and fits a Sequential model on a
# 4-sample toy dataset, then keeps the model loaded from 'temp' instead;
# predict() offsets the model output by self.modifier.
# NOTE(review): this paste lost its indentation and some lines; the notes
# below flag the visible gaps.
class Prediction(Pre):
def __init__(self):
# NOTE(review): the layer list is never closed here. Further layers (the file
# imports Activation), the closing `])`, and presumably a model.compile(...)
# call appear to be missing — confirm against the original script.
model = Sequential([
Dense(32, input_shape=(2,)),
# Toy training data: 2-feature inputs and one-hot 2-column targets; column 0
# is set exactly when the two inputs differ (an XOR pattern).
x = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y = np.array([[0, 1], [1, 0], [1, 0], [0, 1]])
model.fit(x, y, epochs=20)
# NOTE(review): the freshly trained `model` is not used after this point; the
# instance keeps the model loaded from 'temp'. No model.save('temp') call is
# visible, so presumably one was lost from this paste — confirm.
self.model = load_model('temp')
# Constant added to every prediction in predict().
self.modifier = 2
def predict(self, input_array):
# Wrap input_array in a batch of one, run the model, and take the first
# (only) output row.
prediction = self.model.predict(np.array([input_array]))[0]
prediction += self.modifier
# Return only the first element of the offset output row.
return prediction[0]
# Container that owns one Prediction instance, and therefore its Keras model.
# NOTE(review): whenever a B instance is passed to or returned from a Pool
# worker, multiprocessing pickles it, which also pickles the Keras model held
# in self.pred.
class B:
def __init__(self):
# Builds, trains, and reloads the model (see Prediction.__init__).
self.pred = Prediction()
# Pool worker. It stores `modifier`, then sums 100 predictions made on random
# 2-element inputs whose entries are drawn from {0, 1}.
# Args:
#   pred_inner: B instance holding the Prediction to query.
#   modifier: value to apply; the __main__ block passes probe_size (100).
# NOTE(review): the assignment below sets an attribute on the B instance, but
# Prediction.predict() reads self.modifier on the Prediction object
# (pred_inner.pred). As written it therefore does not affect sum_all; the
# likely intent is `pred_inner.pred.modifier = modifier` — confirm.
def calculate_something(pred_inner: B, modifier: int):
pred_inner.modifier = modifier
sum_all = sum(pred_inner.pred.predict(np.array([random.choice([0, 1]), random.choice([0, 1])])) for _ in range(100))
# do some modifi
# NOTE(review): the comment above and the return tuple below are truncated in
# this paste.
# NOTE(review): returning pred_inner sends the whole B object, including its
# Keras model, back to the parent process, which must unpickle it. The posted
# traceback fails at exactly that step: multiprocessing pool _handle_results
# -> keras network __setstate__ -> unpickle_model -> AttributeError on
# _SYMBOLIC_SCOPE.value. Returning only plain values (e.g. sum_all) would avoid
# unpickling the model in the parent — suggestion, TODO confirm.
return (pred_inner,
# Script entry point. Builds one B (training the model once), then for each i
# in range(1000) opens a fresh Pool and runs calculate_something i times with
# (pred, probe_size) as arguments.
# The __main__ guard matters here: the traceback paths are Windows paths, and
# on Windows multiprocessing's default "spawn" start method imports this
# module again in every worker process.
if __name__ == '__main__':
probe_size = 100
pred = B()
# NOTE(review): a new Pool is created and torn down on every one of the 1000
# iterations; hoisting a single Pool outside the loop would avoid that
# overhead. When i == 0, no tasks are submitted.
for i in range(1000):
with Pool() as pool:
# Each task pickles `pred`, including its Keras model, to send it to a worker.
results = pool.starmap(calculate_something, zip(itertools.repeat(pred),
[probe_size for _ in range(i)]))
# NOTE(review): the body of this loop is missing from this paste.
for r in results:
Since I call the predict() function in a subprocess, it seems to run into a conflict with a subprocess that Keras/TensorFlow starts on its own.
My networks are very small, so I don't think multiprocessing is strictly necessary. Is there any way to disable multiprocessing in Keras and TensorFlow?
Alternatively, is there another API I could use instead of Keras/TensorFlow?
Exception in thread Thread-24:
Traceback (most recent call last):
  File "C:\Python37\lib\threading.py", line 926, in _bootstrap_inner
    self.run()
  File "C:\Python37\lib\threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "C:\Python37\lib\multiprocessing\pool.py", line 470, in _handle_results
    task = get()
  File "C:\Python37\lib\multiprocessing\connection.py", line 251, in recv
    return _ForkingPickler.loads(buf.getbuffer())
  File "C:\Users\phhor\PycharmProjects\py_doku\venv37\lib\site-packages\keras\engine\network.py", line 1334, in __setstate__
    model = saving.unpickle_model(state)
  File "C:\Users\phhor\PycharmProjects\py_doku\venv37\lib\site-packages\keras\engine\saving.py", line 604, in unpickle_model
    return _deserialize_model(h5dict)
  File "C:\Users\phhor\PycharmProjects\py_doku\venv37\lib\site-packages\keras\engine\saving.py", line 274, in _deserialize_model
    model = model_from_config(model_config, custom_objects=custom_objects)
  File "C:\Users\phhor\PycharmProjects\py_doku\venv37\lib\site-packages\keras\engine\saving.py", line 627, in model_from_config
    return deserialize(config, custom_objects=custom_objects)
  File "C:\Users\phhor\PycharmProjects\py_doku\venv37\lib\site-packages\keras\layers\__init__.py", line 168, in deserialize
    printable_module_name='layer')
  File "C:\Users\phhor\PycharmProjects\py_doku\venv37\lib\site-packages\keras\utils\generic_utils.py", line 147, in deserialize_keras_object
    list(custom_objects.items())))
  File "C:\Users\phhor\PycharmProjects\py_doku\venv37\lib\site-packages\keras\engine\sequential.py", line 302, in from_config
    model.add(layer)
  File "C:\Users\phhor\PycharmProjects\py_doku\venv37\lib\site-packages\keras\engine\sequential.py", line 162, in add
    name=layer.name + '_input')
  File "C:\Users\phhor\PycharmProjects\py_doku\venv37\lib\site-packages\keras\engine\input_layer.py", line 178, in Input
    input_tensor=tensor)
  File "C:\Users\phhor\PycharmProjects\py_doku\venv37\lib\site-packages\keras\legacy\interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "C:\Users\phhor\PycharmProjects\py_doku\venv37\lib\site-packages\keras\engine\input_layer.py", line 87, in __init__
    name=self.name)
  File "C:\Users\phhor\PycharmProjects\py_doku\venv37\lib\site-packages\keras\backend\tensorflow_backend.py", line 73, in symbolic_fn_wrapper
    if _SYMBOLIC_SCOPE.value:
AttributeError: '_thread._local' object has no attribute 'value'