diff --git a/gqcnn/model/tf/network_tf.py b/gqcnn/model/tf/network_tf.py index 8b90d135..d2a86551 100644 --- a/gqcnn/model/tf/network_tf.py +++ b/gqcnn/model/tf/network_tf.py @@ -552,6 +552,19 @@ def initialize_network(self, if add_sigmoid: self.add_sigmoid_to_output() + # Freeze graph. Make it 'True' to freeze this graph + if False: + from tensorflow.python.framework import graph_io, graph_util + self.open_session() + frozen = graph_util.convert_variables_to_constants( + self._sess, self._sess.graph_def, ["softmax/Softmax"]) + graph_io.write_graph(frozen, + ".", + "inference_graph_frozen.pb", + as_text=False) + self.close_session() + self._logger.info("Wrote frozen graph") + # Create feed tensors for prediction. self._input_im_arr = np.zeros((self._batch_size, self._im_height, self._im_width, self._num_channels))