I want to convert a model to TFLite format. However, I keep getting an error that the operator BroadcastTo is not supported. The only way I have been able to get around this error is by defining my model as a concrete function. How do I train a concrete function? Is it even possible?
(This is not my actual model, just a minimal example that reproduces the error.)
```python
import tensorflow as tf

# -------------------- Doesn't Work --------------------
class CustomLayer(tf.keras.layers.Layer):
    def __init__(self, num_outputs):
        super(CustomLayer, self).__init__()

    def call(self, inputs):
        trans = tf.ones([1, 25, 37, 12])
        trans = tf.math.add(trans, inputs)
        m1s = tf.ones([1, 25, 37, 12, 5, 5])
        reshape = tf.reshape(trans, [1, 25, 37, 12, 1, 1])
        f = tf.multiply(reshape, m1s)
        return f

# shape must be a tuple: (1,) rather than (1)
inputs = tf.keras.Input(shape=(1,), dtype=tf.float32)
f = CustomLayer(1)(inputs)
model = tf.keras.Model(inputs=inputs, outputs=f)
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
open("model.tflite", "wb").write(tflite_model)
```
```python
# -------------------- Concrete Function (Works) --------------------
root = tf.Module()
root.var = None

@tf.function
def example(number):
    trans = tf.ones([1, 25, 37, 12])
    trans = tf.add(trans, number)
    m1s = tf.ones([1, 25, 37, 12, 5, 5])
    reshape = tf.reshape(trans, [1, 25, 37, 12, 1, 1])
    f = tf.multiply(reshape, m1s)
    return f

root.func = example
# Note: 3 is a Python int, so it is traced into the graph as a constant
# rather than becoming a tensor input of the concrete function.
concrete_func = root.func.get_concrete_function(3)
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
tflite_model = converter.convert()
open("model.tflite", "wb").write(tflite_model)
```
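On the training question: a concrete function is a trace of a `tf.function`, so one possible approach is to keep the weights as `tf.Variable`s on a `tf.Module`, train them eagerly with `tf.GradientTape`, and only trace the concrete function for conversion once training is done. The sketch below assumes a hypothetical scalar `bias` parameter, a hypothetical constant training target and plain gradient descent; none of these are part of the original example. One caveat: the working version above traces with the Python int `3`, which becomes a constant, so the converter likely folds the whole computation into a constant and never has to emit BroadcastTo. Giving the function an `input_signature` makes `number` a real tensor input, and the BroadcastTo error may come back; the `SELECT_TF_OPS` line is included as a fallback for that case, at the cost of needing the Select TF ops (Flex) runtime on the device.

```python
import tensorflow as tf


class ExampleModule(tf.Module):
    """Hypothetical trainable version of the question's `example` function."""

    def __init__(self):
        super().__init__()
        # Hypothetical trainable parameter, not in the original question.
        self.bias = tf.Variable(0.0, name="bias")

    # A fixed input signature makes `number` a real tensor input instead of
    # a Python constant baked into the trace.
    @tf.function(input_signature=[tf.TensorSpec(shape=[1], dtype=tf.float32)])
    def __call__(self, number):
        trans = tf.ones([1, 25, 37, 12]) + number + self.bias
        m1s = tf.ones([1, 25, 37, 12, 5, 5])
        reshape = tf.reshape(trans, [1, 25, 37, 12, 1, 1])
        return tf.multiply(reshape, m1s)


def train_step(module, number, target, learning_rate=0.1):
    """One plain gradient-descent step on all of the module's variables."""
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(module(number) - target))
    grads = tape.gradient(loss, module.trainable_variables)
    for var, grad in zip(module.trainable_variables, grads):
        var.assign_sub(learning_rate * grad)
    return float(loss)  # loss measured before this step's update


module = ExampleModule()
number = tf.constant([3.0])
target = tf.fill([1, 25, 37, 12, 5, 5], 2.0)  # hypothetical training target

losses = [train_step(module, number, target) for _ in range(50)]

# Trace the trained function and convert it.
concrete_func = module.__call__.get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
# Fallback in case BroadcastTo is not a builtin op in your TFLite version.
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
tflite_model = converter.convert()
```

Training and export share one function definition here; in this conversion path the trained variable values are frozen into constants inside the `.tflite` file.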
Note that I have already tried the following solutions:
- Defining the model in Keras (so it can be trained easily) and using `tf.lite.TFLiteConverter.from_keras_model`
- Saving the Keras model as a SavedModel and using `tf.lite.TFLiteConverter.from_saved_model`
- Saving the Keras model as a SavedModel and getting the concrete function from it using `concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]`
I know it is also possible to write a custom operator, but that would require advanced knowledge of TensorFlow's C++ API, knowing how BroadcastTo works internally, knowing where to put the files, compiling a custom AAR, and building a custom JNI layer.
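Short of writing a custom kernel, another thing that may be worth trying is removing the implicit broadcast from the graph: if both operands of `tf.multiply` already have the same shape, no broadcasting is needed. The sketch below (the function names `broadcast_version` and `tile_version` are hypothetical) expands the reshaped tensor to the full 6-D shape with `tf.tile` before multiplying, and checks that the result matches the original computation. Whether the converter then handles a rank-6 `TILE` as a builtin depends on the TFLite version, so it is worth confirming with the converter before relying on it.

```python
import tensorflow as tf


def broadcast_version(number):
    # The question's formulation: multiply relies on implicit broadcasting
    # from [..., 1, 1] up to [..., 5, 5].
    trans = tf.ones([1, 25, 37, 12]) + number
    m1s = tf.ones([1, 25, 37, 12, 5, 5])
    reshape = tf.reshape(trans, [1, 25, 37, 12, 1, 1])
    return tf.multiply(reshape, m1s)


def tile_version(number):
    # Same values, but the last two axes are repeated 5 times explicitly,
    # so both multiply operands have identical shapes.
    trans = tf.ones([1, 25, 37, 12]) + number
    m1s = tf.ones([1, 25, 37, 12, 5, 5])
    reshape = tf.reshape(trans, [1, 25, 37, 12, 1, 1])
    tiled = tf.tile(reshape, [1, 1, 1, 1, 5, 5])
    return tf.multiply(tiled, m1s)


number = tf.constant([3.0])
same = bool(tf.reduce_all(tf.equal(broadcast_version(number), tile_version(number))))
```

Because `m1s` is kept as a separate operand, the same rewrite still applies if, in the real model, it is not just a tensor of ones.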