3 votes

I want to use my TensorFlow algorithm in an Android app. The TensorFlow Android example starts by downloading a GraphDef that contains the model definition and weights (in a *.pb file). Now this file should come from my Scikit Flow algorithm (part of TensorFlow).

At first glance it seems easy: you just call classifier.save('model/'). But the files saved to that folder are not *.ckpt, *.def and certainly not *.pb. Instead you have to deal with a *.pbtxt file and a checkpoint file (without an extension).

I've been stuck there for quite a while. Here is a code example that exports something:

# imports
import tensorflow as tf
import tensorflow.contrib.learn as skflow
import tensorflow.contrib.learn.python.learn as learn
from sklearn import datasets, metrics

# skflow example
iris = datasets.load_iris()
feature_columns = learn.infer_real_valued_columns_from_input(iris.data)
classifier = learn.LinearClassifier(n_classes=3, feature_columns=feature_columns, model_dir="modeltest")
classifier.fit(iris.data, iris.target, steps=200, batch_size=32)
iris_predictions = list(classifier.predict(iris.data, as_iterable=True))
score = metrics.accuracy_score(iris.target, iris_predictions)
print("Accuracy: %f" % score)

The files you get are:

  • checkpoint
  • graph.pbtxt
  • model.ckpt-1.meta
  • model.ckpt-1-00000-of-00001
  • model.ckpt-200.meta
  • model.ckpt-200-00000-of-00001

Many possible workarounds I found would require having the GraphDef in a variable (I don't know how to get that with Scikit Flow), or a TensorFlow session, which doesn't seem to be required when using Scikit Flow.
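For what it's worth, getting the GraphDef into a variable from the graph.pbtxt that classifier.save() writes does seem possible, since that file is just a text-format GraphDef that protobuf's text_format can parse. An untested sketch (the tiny stand-in graph here only substitutes for the real skflow one; on TensorFlow 1.x the compat import is just `import tensorflow as tf`):

```python
import tensorflow.compat.v1 as tf  # on TensorFlow 1.x: `import tensorflow as tf`
from google.protobuf import text_format

tf.disable_eager_execution()  # needed on TensorFlow 2.x only

# Stand-in graph; in the real case classifier.save('modeltest') already
# wrote modeltest/graph.pbtxt
a = tf.constant(1.0, name='a')
tf.train.write_graph(tf.get_default_graph().as_graph_def(),
                     'modeltest', 'graph.pbtxt', as_text=True)

# Parse the text-format file back into a GraphDef variable
graph_def = tf.GraphDef()
with open('modeltest/graph.pbtxt') as f:
    text_format.Merge(f.read(), graph_def)
```

That would at least give me the GraphDef in a variable, though the weights would still live in the checkpoint files.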

1
Have you managed to find a solution? - idoshamun
I decided to use Scikit Flow for experimenting (how many layers do I need for my NN, etc.) and then recreated the model in pure TensorFlow. I avoided the whole freeze_graph Bazel stuff by creating a second model with the already trained weights as constants (I switched to iOS, but it might be the same for Android). That's not really a recommendation, just the path I took - CodingYourLife

1 Answer

2 votes

To save a *.pb file, you need to extract the graph_def from the constructed graph. You can do that as follows:

from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile

sess = tf.Session()
final_tensor_name = 'results'     # Replace with the name of the final op in your graph (node name, without ':0')

######### Build your graph and train ########
## Your tensorflow code to build the graph
#############################################

output_filename = 'output_graph.pb'
output_graph_def = sess.graph.as_graph_def()
with gfile.FastGFile(output_filename, 'wb') as f:
  f.write(output_graph_def.SerializeToString())

If you want to convert your trained variables to constants (to avoid using ckpt files to load the weights), you can use:

output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), [final_tensor_name])
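Putting the two pieces together, here is a minimal end-to-end sketch on a toy graph (the op names 'input', 'weight' and 'results' are just placeholders for whatever your real graph uses; on TensorFlow 1.x the compat import is plain `import tensorflow as tf`):

```python
import tensorflow.compat.v1 as tf  # on TensorFlow 1.x: `import tensorflow as tf`
from tensorflow.compat.v1 import graph_util

tf.disable_eager_execution()  # needed on TensorFlow 2.x only

# Toy stand-in graph: a trainable weight applied to an input,
# with the final op named 'results'
x = tf.placeholder(tf.float32, shape=[None], name='input')
w = tf.Variable([2.0], name='weight')
y = tf.multiply(x, w, name='results')

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Freeze: bake the variable's current value into the graph as a constant,
# keeping only the subgraph needed to compute 'results'
frozen_graph_def = graph_util.convert_variables_to_constants(
    sess, sess.graph.as_graph_def(), ['results'])

# Write the frozen GraphDef as a binary *.pb file
with tf.gfile.FastGFile('output_graph.pb', 'wb') as f:
    f.write(frozen_graph_def.SerializeToString())
```

The resulting output_graph.pb contains both structure and weights, so the Android side no longer needs the *.ckpt files.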

Hope this helps!