You can feed the output of one graph into another by passing `input_map` to `tf.import_graph_def`.
You also have to rename the `while_context` frame names, because each graph contains its own while loop.
Something like this:
def get_frozen_graph(graph_file):
    """Load a frozen graph definition stored at *graph_file*.

    The file is opened in binary mode through ``tf.gfile.GFile`` and its
    bytes are parsed into a newly created ``tf.GraphDef``.

    Args:
        graph_file: Path of the serialized graph file to read.

    Returns:
        The ``tf.GraphDef`` populated from the file contents.
    """
    with tf.gfile.GFile(graph_file, "rb") as pb_file:
        serialized = pb_file.read()

    frozen = tf.GraphDef()
    frozen.ParseFromString(serialized)
    return frozen
def rename_frame_name(graphdef, suffix):
    """Append *suffix* to the ``while_context`` part of loop frame names.

    Every node whose name contains ``"while"`` and that carries a
    ``frame_name`` attribute gets each ``"while_context"`` occurrence in
    that attribute replaced by ``"while_context" + suffix``. The graph
    definition is modified in place.

    Args:
        graphdef: Graph definition whose ``node`` entries expose ``name``
            and an ``attr`` mapping; ``attr["frame_name"].s`` is expected
            to hold the frame name as UTF-8 bytes.
        suffix: Text appended after ``while_context``.

    Returns:
        None.
    """
    for node in graphdef.node:
        if "while" not in node.name or "frame_name" not in node.attr:
            continue
        frame_attr = node.attr["frame_name"]
        # Read the bytes field itself. Calling str() on the attribute
        # object would stringify the whole attribute rather than the frame
        # name, and that rendering would then be written back into `.s`.
        frame_name = frame_attr.s.decode("utf-8")
        frame_attr.s = frame_name.replace(
            "while_context", "while_context" + suffix
        ).encode("utf-8")
...
# Build the first-level model in a graph of its own.
l1_graph = tf.Graph()
with l1_graph.as_default():
    trt_graph1 = get_frozen_graph(pb_fname1)

    # Import the frozen graph and grab handles to its input and outputs.
    (
        tf_input1,
        tf_scores1,
        tf_boxes1,
        tf_classes1,
        tf_num_detections1,
    ) = tf.import_graph_def(
        trt_graph1,
        return_elements=[
            'image_tensor:0',
            'detection_scores:0',
            'detection_boxes:0',
            'detection_classes:0',
            'num_detections:0',
        ],
    )

    # Re-expose the tensors under fixed "l1_*" names.
    # NOTE(review): the [0] index presumably drops a leading batch
    # dimension of size 1 -- confirm against the model's output shapes.
    input1 = tf.identity(tf_input1, name="l1_input")
    boxes1 = tf.identity(tf_boxes1[0], name="l1_boxes")
    scores1 = tf.identity(tf_scores1[0], name="l1_scores")
    classes1 = tf.identity(tf_classes1[0], name="l1_classes")
    # The detection count is cast to int32 before being re-exposed.
    num_detections1 = tf.identity(
        tf.dtypes.cast(tf_num_detections1[0], tf.int32),
        name="l1_num_detections",
    )
...
tf_out =
...
# Merge both models into a single graph.
connected_graph = tf.Graph()
with connected_graph.as_default():
    # First model: serialize the graph built earlier, give its loop frames
    # a graph-specific suffix, then import it under the same name.
    l1_graph_def = l1_graph.as_graph_def()
    g1name = 'ved'
    rename_frame_name(l1_graph_def, g1name)
    tf.import_graph_def(l1_graph_def, name=g1name)
    ...
    # Second model: loaded from disk and renamed with its own suffix.
    trt_graph2 = get_frozen_graph(pb_fname2)
    g2name = 'level2'
    rename_frame_name(trt_graph2, g2name)
    # `input_map` replaces the second model's 'image_tensor' input with
    # `tf_out`, so the second model consumes that tensor instead.
    # NOTE(review): g2name is only used for the frame renaming; no `name=`
    # is passed to this import -- confirm that is intended.
    (
        tf_scores,
        tf_boxes,
        tf_classes,
        tf_num_detections,
    ) = tf.import_graph_def(
        trt_graph2,
        input_map={'image_tensor': tf_out},
        return_elements=[
            'detection_scores:0',
            'detection_boxes:0',
            'detection_classes:0',
            'num_detections:0',
        ],
    )
# Export the combined graph as a SavedModel under ./saved_model.
with connected_graph.as_default():
    print('\nSaving...')

    cwd = os.getcwd()
    path = os.path.join(cwd, 'saved_model')
    # Remove any previous export; errors (e.g. a missing directory) are
    # ignored.
    shutil.rmtree(path, ignore_errors=True)

    # Signature of the exported model: one input, and the outputs of both
    # levels.
    inputs_dict = {"image_tensor": tf_input}
    outputs_dict = {
        "detection_boxes_l1": tf_boxes_l1,
        "detection_scores_l1": tf_scores_l1,
        "detection_classes_l1": tf_classes_l1,
        "max_num_detection": tf_max_num_detection,
        "detection_boxes_l2": tf_boxes_l2,
        "detection_scores_l2": tf_scores_l2,
        "detection_classes_l2": tf_classes_l2,
    }

    tf.saved_model.simple_save(tf_sess_main, path, inputs_dict, outputs_dict)
    print('Ok')