I want to use my TensorFlow algorithm in an Android app. The TensorFlow Android example starts by downloading a GraphDef that contains the model definition and weights (in a *.pb file). I need to produce that file from my Scikit Flow model (Scikit Flow is part of TensorFlow).
At first glance it seems easy: you just call classifier.save('model/'). But the files saved to that folder are not *.ckpt, *.def, and certainly not *.pb. Instead you have to deal with a *.pbtxt file and a checkpoint file (without an extension).
I have been stuck on this for quite a while. Here is a code example that exports something:
#imports
import tensorflow as tf
import tensorflow.contrib.learn as skflow
import tensorflow.contrib.learn.python.learn as learn
from sklearn import datasets, metrics
#skflow example
iris = datasets.load_iris()
feature_columns = learn.infer_real_valued_columns_from_input(iris.data)
classifier = learn.LinearClassifier(n_classes=3, feature_columns=feature_columns, model_dir="modeltest")
classifier.fit(iris.data, iris.target, steps=200, batch_size=32)
iris_predictions = list(classifier.predict(iris.data, as_iterable=True))
score = metrics.accuracy_score(iris.target, iris_predictions)
print("Accuracy: %f" % score)
The files you get are:
- checkpoint
- graph.pbtxt
- model.ckpt-1.meta
- model.ckpt-1-00000-of-00001
- model.ckpt-200.meta
- model.ckpt-200-00000-of-00001
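For context on the file list above: the extensionless checkpoint file is normally a small text file that records which checkpoint prefix is the latest. The following is a minimal sketch that reads it without TensorFlow, assuming the usual layout of `model_checkpoint_path: "..."` and `all_model_checkpoint_paths: "..."` lines. The helper name `read_checkpoint_state` is made up for illustration and is not a TensorFlow function.

```python
import re


def read_checkpoint_state(path):
    """Return (latest_prefix, all_prefixes) from a 'checkpoint' index file.

    Assumes one `key: "value"` pair per line, which is how TensorFlow
    normally writes this file. Lines that do not match are ignored.
    """
    pattern = re.compile(r'^\s*(\w+)\s*:\s*"(.*)"\s*$')
    latest = None
    all_prefixes = []
    with open(path, "r") as f:
        for line in f:
            match = pattern.match(line)
            if match is None:
                continue
            key, value = match.groups()
            if key == "model_checkpoint_path":
                latest = value
            elif key == "all_model_checkpoint_paths":
                all_prefixes.append(value)
    return latest, all_prefixes


# Hypothetical usage with the directory from the example above:
# latest, prefixes = read_checkpoint_state("modeltest/checkpoint")
```

Inside TensorFlow itself, `tf.train.latest_checkpoint(model_dir)` returns the same "latest" prefix, so the hand-written parser is only useful for inspecting the folder from outside.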
Many of the possible workarounds I found require either having the GraphDef in a variable (I don't know how to get that with Scikit Flow) or a TensorFlow session, which doesn't seem to be needed when using Scikit Flow.
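One way to get a GraphDef into a variable, sketched here under assumptions rather than as a confirmed solution, is to parse the exported graph.pbtxt with protobuf's `text_format` and then write it back out in binary form. The helper names `load_graph_def` and `write_binary_graph` are made up for illustration. Note that the graph file only describes the structure: the trained weights stay in the model.ckpt-* files, so the resulting *.pb would still have to be combined with a checkpoint (for example with the freeze_graph tool in tensorflow/python/tools) before it is usable for inference.

```python
import tensorflow as tf
from google.protobuf import text_format

# tf.GraphDef is the TF 1.x name; newer releases expose it under tf.compat.v1.
GraphDef = getattr(tf, "GraphDef", None) or tf.compat.v1.GraphDef


def load_graph_def(pbtxt_path):
    """Parse a text-format graph file (*.pbtxt) into a GraphDef message."""
    graph_def = GraphDef()
    with open(pbtxt_path, "r") as f:
        text_format.Merge(f.read(), graph_def)
    return graph_def


def write_binary_graph(graph_def, pb_path):
    """Serialize a GraphDef to the binary *.pb format (structure only, no weights)."""
    with open(pb_path, "wb") as f:
        f.write(graph_def.SerializeToString())


# Hypothetical usage with the directory from the example above:
# graph_def = load_graph_def("modeltest/graph.pbtxt")
# write_binary_graph(graph_def, "modeltest/graph.pb")
```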