diff --git a/cfg/examples/replication/dex-net_2.0.yaml b/cfg/examples/replication/dex-net_2.0.yaml
index f2ba3ac9..7852e06f 100644
--- a/cfg/examples/replication/dex-net_2.0.yaml
+++ b/cfg/examples/replication/dex-net_2.0.yaml
@@ -64,6 +64,8 @@ policy:
   metric:
     type: gqcnn
     gqcnn_model: models/GQCNN-2.0
+    # openvino: OFF|CPU|GPU|MYRIAD
+    openvino: OFF
 
     crop_height: 96
     crop_width: 96
diff --git a/cfg/examples/replication/dex-net_2.1.yaml b/cfg/examples/replication/dex-net_2.1.yaml
index 93391ea9..5fbc48a2 100644
--- a/cfg/examples/replication/dex-net_2.1.yaml
+++ b/cfg/examples/replication/dex-net_2.1.yaml
@@ -64,7 +64,9 @@ policy:
   metric:
     type: gqcnn
     gqcnn_model: models/GQCNN-2.1
-
+    # openvino: OFF|CPU|GPU|MYRIAD
+    openvino: OFF
+
     crop_height: 96
     crop_width: 96
diff --git a/cfg/examples/replication/dex-net_3.0.yaml b/cfg/examples/replication/dex-net_3.0.yaml
index e8b9bca2..9e551f4e 100644
--- a/cfg/examples/replication/dex-net_3.0.yaml
+++ b/cfg/examples/replication/dex-net_3.0.yaml
@@ -58,6 +58,8 @@ policy:
   metric:
     type: gqcnn
     gqcnn_model: models/GQCNN-3.0
+    # openvino: OFF|CPU|GPU|MYRIAD
+    openvino: OFF
 
     crop_height: 96
     crop_width: 96
diff --git a/cfg/examples/replication/dex-net_4.0_pj.yaml b/cfg/examples/replication/dex-net_4.0_pj.yaml
index 0e970c56..f4f8a964 100644
--- a/cfg/examples/replication/dex-net_4.0_pj.yaml
+++ b/cfg/examples/replication/dex-net_4.0_pj.yaml
@@ -63,6 +63,8 @@ policy:
   metric:
     type: gqcnn
     gqcnn_model: models/GQCNN-4.0-PJ
+    # openvino: OFF|CPU|GPU|MYRIAD
+    openvino: OFF
 
     crop_height: 96
     crop_width: 96
diff --git a/cfg/examples/replication/dex-net_4.0_suction.yaml b/cfg/examples/replication/dex-net_4.0_suction.yaml
index 1d8d5d9c..3bebcfe0 100644
--- a/cfg/examples/replication/dex-net_4.0_suction.yaml
+++ b/cfg/examples/replication/dex-net_4.0_suction.yaml
@@ -57,6 +57,8 @@ policy:
   metric:
     type: gqcnn
     gqcnn_model: models/GQCNN-4.0-SUCTION
+    # openvino: OFF|CPU|GPU|MYRIAD
+    openvino: OFF
 
     crop_height: 96
     crop_width: 96
diff --git a/docs/source/images/gqcnn_openvino.png b/docs/source/images/gqcnn_openvino.png
new file mode 100755
index 00000000..b9ef1402
Binary files /dev/null and b/docs/source/images/gqcnn_openvino.png differ
diff --git a/docs/source/index.rst b/docs/source/index.rst
index d985dfff..310274a7 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -111,6 +111,12 @@ If you use the code, datasets, or models in a publication, please cite the appro
    replication/replication.rst
 
+.. toctree::
+   :maxdepth: 2
+   :caption: Enabling OpenVINO™
+
+   openvino/openvino.rst
+
 .. toctree::
    :maxdepth: 2
    :caption: Benchmarks
diff --git a/docs/source/openvino/openvino.rst b/docs/source/openvino/openvino.rst
new file mode 100644
index 00000000..c60322a5
--- /dev/null
+++ b/docs/source/openvino/openvino.rst
@@ -0,0 +1,72 @@
+Enabling OpenVINO™
+~~~~~~~~~~~~~~~~~~
+
+This tutorial describes how to enable OpenVINO™ for GQCNN deployment on Intel® devices.
+
+The Intel® Distribution of OpenVINO™ (Open Visual Inference & Neural Network Optimization) toolkit, based on convolutional neural networks (CNNs), extends computer vision workloads across Intel® hardware (including accelerators) and maximizes performance. The toolkit enables deep learning inference at the edge and supports heterogeneous execution across various computer vision devices -- CPU, GPU, Intel® Movidius™ NCS2, and FPGA -- using a **common** API.
+
+.. image:: ../images/gqcnn_openvino.png
+   :width: 440 px
+   :height: 250 px
+   :align: center
+
+Install OpenVINO™ Toolkit
+=========================
+The toolkit is available from the open source project `Intel® OpenVINO™ Toolkit`_.
+
+.. note:: GQCNN uses two layers, RandomUniform and Floor, which are not supported by the 2019_R3 release of OpenVINO™. `PR #338 <https://github.com/opencv/dldt/pull/338>`_ adds the support for these two layers.
+
+You may get started with `Build Inference Engine`_. The ``Introduction`` section lists the supported device types, and the ``Build on Linux* Systems`` section describes how to build and install the toolkit. Here are the CMake options used for reference; adapt them to your specific environment (``<dldt_dir>`` below denotes the root of the cloned ``dldt`` repository). ::
+
+    $ cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=/usr/local -DGEMM=OPENBLAS -DBLAS_INCLUDE_DIRS=/usr/include/x86_64-linux-gnu -DBLAS_LIBRARIES=/usr/lib/x86_64-linux-gnu/openblas/libblas.so -DENABLE_MKL_DNN=ON -DENABLE_CLDNN=ON -DENABLE_PYTHON=ON -DPYTHON_EXECUTABLE=`which python3.6` -DPYTHON_LIBRARY=/usr/lib/x86_64-linux-gnu/libpython3.6m.so -DPYTHON_INCLUDE_DIR=/usr/include/python3.6 ..
+
+Then install the ``Model Optimizer`` prerequisites (the TensorFlow set suffices for GQCNN). ::
+
+    $ cd <dldt_dir>/model-optimizer
+    $ sudo pip3 install -r requirements_tf.txt
+
+And set up the environment for the toolkit. ::
+
+    $ export InferenceEngine_DIR=<dldt_dir>/inference-engine/build
+    $ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<dldt_dir>/inference-engine/bin/intel64/Release/lib
+    $ export PYTHONPATH=<dldt_dir>/inference-engine/bin/intel64/Release/lib/python_api/python3.6:<dldt_dir>/model-optimizer:$PYTHONPATH
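+
+As a quick sanity check (a minimal sketch, not part of the upstream build steps; it assumes the environment variables above are set), the Inference Engine Python API should now be importable::
+
+    from openvino.inference_engine import IECore
+
+    # Lists the devices visible to the Inference Engine,
+    # e.g. ['CPU', 'GPU', 'MYRIAD'].
+    print(IECore().available_devices)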
+
+Freeze a GQCNN Model
+====================
+The ``Model Optimizer`` of OpenVINO™ expects a frozen graph file. Freezing is done in ``gqcnn.model.tf.GQCNNTF.initialize_network()`` right after the graph is built. Look for the code below. ::
+
+    # Freeze graph. Set this to `True` to write the frozen graph.
+    if False:
+
+Switch the ``if`` condition to ``True`` and run an example policy with a specific GQCNN model (refer to ``scripts/policies/run_all_dex-net_<version>_examples.sh``); the frozen graph will be written to a file named ``inference_graph_frozen.pb``.
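+
+For reference, the freezing logic behind this switch is shown below (abridged from the ``gqcnn/model/tf/network_tf.py`` change in this patch; it converts all variables to constants and writes the result)::
+
+    from tensorflow.python.framework import graph_io, graph_util
+
+    self.open_session()
+    # Keep "softmax/Softmax" as the output node of the frozen graph.
+    frozen = graph_util.convert_variables_to_constants(
+        self._sess, self._sess.graph_def, ["softmax/Softmax"])
+    graph_io.write_graph(frozen, ".", "inference_graph_frozen.pb",
+                         as_text=False)
+    self.close_session()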
+
+Convert a GQCNN Model
+=====================
+``mo_tf.py`` is the Model Optimizer script for converting TensorFlow models. ::
+
+    $ sudo python3 <dldt_dir>/model-optimizer/mo_tf.py --input_model inference_graph_frozen.pb --data_type FP16 --output_dir <gqcnn_dir>/models/OpenVINO/<model_name>/FP16
+    $ sudo python3 <dldt_dir>/model-optimizer/mo_tf.py --input_model inference_graph_frozen.pb --data_type FP32 --output_dir <gqcnn_dir>/models/OpenVINO/<model_name>/FP32
+
+Parameters passed to the conversion script:
+ #. ``input_model``: the frozen TensorFlow model to be converted.
+ #. ``output_dir``: the directory for the converted model (``<gqcnn_dir>`` denotes the root of the ``gqcnn`` repository).
+ #. ``data_type``: the data type of the converted model.
+
+For more detailed instructions on model conversion, refer to the `OpenVINO™ Docs`_.
+
+.. note:: ``gqcnn.model.openvino.GQCNNOpenVINO.load_openvino()`` expects to load an OpenVINO model from ``models/OpenVINO/<model_name>/FP16``, where ``<model_name>`` comes from the original GQCNN model, e.g. ``GQCNN-4.0-SUCTION``, ``GQ-Suction``, etc.
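+
+To confirm the conversion succeeded (a minimal sketch, assuming the IR was written to ``models/OpenVINO/<model_name>/FP16`` as above), the IR can be loaded back through the Inference Engine Python API::
+
+    import os
+    from openvino.inference_engine import IENetwork, IECore
+
+    model_xml = "models/OpenVINO/<model_name>/FP16/inference_graph_frozen.xml"
+    model_bin = os.path.splitext(model_xml)[0] + ".bin"
+    net = IENetwork(model_xml, model_bin)  # parse the converted IR
+    exec_net = IECore().load_network(network=net, device_name="CPU")
+    # A GQCNN IR has two inputs (image, pose) and one output.
+    print(list(net.inputs.keys()), list(net.outputs.keys()))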
+
+Evaluate the OpenVINO™ Model with Example Policies
+==================================================
+The GQCNN model has now been converted into an OpenVINO™ model, which you may evaluate with the example policies. Look in ``cfg/examples/replication/dex-net_<version>.yaml`` for the configuration below. ::
+
+    # openvino: OFF|CPU|GPU|MYRIAD
+    openvino: OFF
+
+Set ``openvino`` to CPU, GPU, or MYRIAD. This option specifies the target device on which GQCNN inference executes (the supported device types are listed in the ``Introduction`` section of `Build Inference Engine`_). Then run the example policy in the same way as given in ``scripts/policies/run_all_dex-net_<version>_examples.sh``, e.g. ::
+
+    $ python3 examples/policy.py GQCNN-4.0-SUCTION --depth_image data/examples/clutter/phoxi/dex-net_4.0/depth_0.npy --segmask data/examples/clutter/phoxi/dex-net_4.0/segmask_0.png --config_filename cfg/examples/replication/dex-net_4.0_suction.yaml --camera_intr data/calib/phoxi/phoxi.intr
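+
+Under the hood, the ``openvino`` setting selects the GQCNN backend. Below is a minimal sketch of the dispatch implemented in ``gqcnn/grasping/grasp_quality_function.py`` (``config`` and ``model_dir`` are hypothetical stand-ins for the policy configuration and model path, and the ``openvino`` key is assumed present)::
+
+    from gqcnn.model import get_gqcnn_model
+
+    if config["openvino"] == "OFF":
+        # Default TensorFlow backend.
+        gqcnn = get_gqcnn_model().load(model_dir)
+    else:
+        # OpenVINO backend; the value names the target device.
+        gqcnn = get_gqcnn_model("openvino").load(model_dir, config["openvino"])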
+
+.. _Intel® OpenVINO™ Toolkit: https://github.com/opencv/dldt
+.. _Build Inference Engine: https://github.com/opencv/dldt/blob/2019/inference-engine/README.md
+.. _OpenVINO™ Docs: https://docs.openvinotoolkit.org/latest/_docs_MO_DG_prepare_model_convert_model_Convert_Model_From_TensorFlow.html
diff --git a/gqcnn/grasping/grasp_quality_function.py b/gqcnn/grasping/grasp_quality_function.py
index 5d5637e2..61ac237b 100644
--- a/gqcnn/grasping/grasp_quality_function.py
+++ b/gqcnn/grasping/grasp_quality_function.py
@@ -928,7 +928,12 @@ def __init__(self, config):
         self._crop_width = config["crop_width"]
 
         # Init GQ-CNN
-        self._gqcnn = get_gqcnn_model().load(self._gqcnn_model_dir)
+        if ("openvino" in self._config and
+                self._config["openvino"] != "OFF"):
+            self._gqcnn = get_gqcnn_model("openvino").load(
+                self._gqcnn_model_dir, self._config["openvino"])
+        else:
+            self._gqcnn = get_gqcnn_model().load(self._gqcnn_model_dir)
 
         # Open Tensorflow session for gqcnn.
         self._gqcnn.open_session()
diff --git a/gqcnn/model/__init__.py b/gqcnn/model/__init__.py
index d120daa1..bd21beaa 100644
--- a/gqcnn/model/__init__.py
+++ b/gqcnn/model/__init__.py
@@ -60,6 +60,10 @@ def get_gqcnn_model(backend="tf", verbose=True):
     if backend == "tf":
         logger.info("Initializing GQ-CNN with Tensorflow as backend...")
         return GQCNNTF
+    elif backend == "openvino":
+        from .openvino import GQCNNOpenVINO
+        logger.info("Initializing GQ-CNN with OpenVINO as backend...")
+        return GQCNNOpenVINO
     else:
         raise ValueError("Invalid backend: {}".format(backend))
diff --git a/gqcnn/model/openvino/__init__.py b/gqcnn/model/openvino/__init__.py
new file mode 100644
index 00000000..33eeeaa8
--- /dev/null
+++ b/gqcnn/model/openvino/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+"""
+Copyright (c) 2019 Intel Corporation. All Rights Reserved.
+
+GQ-CNN inference with OpenVINO.
+
+Author
+------
+Sharron LIU
+"""
+
+from .network_openvino import GQCNNOpenVINO
+
+__all__ = ["GQCNNOpenVINO"]
diff --git a/gqcnn/model/openvino/network_openvino.py b/gqcnn/model/openvino/network_openvino.py
new file mode 100644
index 00000000..f746bdb8
--- /dev/null
+++ b/gqcnn/model/openvino/network_openvino.py
@@ -0,0 +1,222 @@
+# -*- coding: utf-8 -*-
+"""
+Copyright (c) 2019 Intel Corporation. All Rights Reserved.
+
+GQ-CNN inference with OpenVINO.
+
+Author
+------
+Sharron LIU
+"""
+from collections import OrderedDict
+import json
+import math
+import os
+import time
+import numpy as np
+
+from autolab_core import Logger
+from ...utils import (InputDepthMode, GQCNNFilenames)
+from ..tf import GQCNNTF
+from openvino.inference_engine import IENetwork, IECore
+
+
+class GQCNNOpenVINO(GQCNNTF):
+    """GQ-CNN network implemented in OpenVINO."""
+
+    BatchSize = 64
+
+    def __init__(self, gqcnn_config, verbose=True, log_file=None):
+        """
+        Parameters
+        ----------
+        gqcnn_config : dict
+            Python dictionary of model configuration parameters.
+        verbose : bool
+            Whether or not to log model output to `stdout`.
+        log_file : str
+            If provided, model output will also be logged to this file.
+        """
+        self._sess = None
+        # Set up logger.
+        self._logger = Logger.get_logger(self.__class__.__name__,
+                                         log_file=log_file,
+                                         silence=(not verbose),
+                                         global_log_file=verbose)
+        self._parse_config(gqcnn_config)
+
+    @staticmethod
+    def load(model_dir, device, verbose=True, log_file=None):
+        """Instantiate a trained GQ-CNN for inference.
+
+        Parameters
+        ----------
+        model_dir : str
+            Path to trained GQ-CNN model.
+        device : str
+            Device on which inference executes: CPU|GPU|MYRIAD.
+        verbose : bool
+            Whether or not to log model output to `stdout`.
+        log_file : str
+            If provided, model output will also be logged to this file.
+
+        Returns
+        -------
+        :obj:`GQCNNOpenVINO`
+            Initialized GQ-CNN.
+        """
+        # Load GQCNN config.
+        config_file = os.path.join(model_dir, GQCNNFilenames.SAVED_CFG)
+        with open(config_file) as data_file:
+            train_config = json.load(data_file, object_pairs_hook=OrderedDict)
+        # Support for legacy configs.
+        try:
+            gqcnn_config = train_config["gqcnn"]
+        except KeyError:
+            gqcnn_config = train_config["gqcnn_config"]
+            gqcnn_config["debug"] = 0
+            gqcnn_config["seed"] = 0
+            # Legacy networks had no angular support.
+            gqcnn_config["num_angular_bins"] = 0
+            # Legacy networks only supported depth integration through pose
+            # stream.
+            gqcnn_config["input_depth_mode"] = InputDepthMode.POSE_STREAM
+
+        # Initialize OpenVINO network.
+        gqcnn = GQCNNOpenVINO(gqcnn_config, verbose=verbose, log_file=log_file)
+        if device == "MYRIAD":  # MYRIAD supports a batch size of 1 only.
+            gqcnn.set_batch_size(1)
+        else:
+            gqcnn.set_batch_size(64)
+
+        # Initialize input tensors for prediction.
+        gqcnn._input_im_arr = np.zeros((gqcnn._batch_size, gqcnn._im_height,
+                                        gqcnn._im_width, gqcnn._num_channels))
+        gqcnn._input_pose_arr = np.zeros((gqcnn._batch_size, gqcnn._pose_dim))
+
+        # Initialize mean and standard deviation tensors.
+        gqcnn.init_mean_and_std(model_dir)
+
+        # Load OpenVINO network on the specified device.
+        gqcnn.load_openvino(model_dir, device)
+
+        return gqcnn
+
+    def open_session(self):
+        # No TensorFlow session to manage for the OpenVINO backend.
+        pass
+
+    def close_session(self):
+        pass
+
+    def load_openvino(self, model_dir, device):
+        """Load the converted OpenVINO model onto the target device."""
+        self._ie = IECore()
+        # Locate the IR next to the original model:
+        # models/OpenVINO/<model_name>/FP16/inference_graph_frozen.xml
+        model_path = os.path.split(model_dir)
+        model_xml = os.path.join(model_path[0], "OpenVINO", model_path[1])
+        model_xml = os.path.join(model_xml, "FP16",
+                                 "inference_graph_frozen.xml")
+        model_bin = os.path.splitext(model_xml)[0] + ".bin"
+        self._vino_net = IENetwork(model_xml, model_bin)
+        self._vino_net.batch_size = self._batch_size
+        assert len(self._vino_net.inputs.keys()) == 2, \
+            "GQCNN expects two input nodes"
+        assert len(self._vino_net.outputs) == 1, \
+            "GQCNN expects one output node"
+        vino_inputs = iter(self._vino_net.inputs)
+        self._input_im = next(vino_inputs)
+        self._input_pose = next(vino_inputs)
+        self._output_blob = next(iter(self._vino_net.outputs))
+        # Load the executable network onto the target device.
+        self._vino_exec_net = self._ie.load_network(network=self._vino_net,
+                                                    device_name=device)
+
+    def unload_openvino(self):
+        """Release the OpenVINO network and Inference Engine handles."""
+        del self._vino_exec_net
+        del self._vino_net
+        del self._ie
+
+    def predict_openvino(self, image_arr, pose_arr, verbose=False):
+        """Predict a set of images in batches.
+
+        Parameters
+        ----------
+        image_arr : :obj:`numpy ndarray`
+            4D tensor of depth images to be predicted.
+        pose_arr : :obj:`numpy ndarray`
+            Tensor of gripper poses to be predicted.
+        verbose : bool
+            Whether or not to log progress to `stdout`.
+        """
+        # Get prediction start time.
+        start_time = time.time()
+
+        if verbose:
+            self._logger.info("Predicting...")
+
+        # Setup for prediction.
+        num_batches = math.ceil(image_arr.shape[0] / self._batch_size)
+        num_images = image_arr.shape[0]
+        num_poses = pose_arr.shape[0]
+
+        output_arr = np.zeros(
+            [num_images] +
+            list(self._vino_net.outputs[self._output_blob].shape[1:]))
+        if num_images != num_poses:
+            raise ValueError("Must provide same number of images as poses!")
+
+        # Predict in batches.
+        i = 0
+        batch_idx = 0
+        while i < num_images:
+            if verbose:
+                self._logger.info(
+                    "Predicting batch {} of {} ({} images total)...".format(
+                        batch_idx, num_batches, num_images))
+            batch_idx += 1
+            dim = min(self._batch_size, num_images - i)
+            cur_ind = i
+            end_ind = cur_ind + dim
+
+            if self._input_depth_mode == InputDepthMode.POSE_STREAM:
+                self._input_im_arr[:dim, ...] = (
+                    image_arr[cur_ind:end_ind, ...] -
+                    self._im_mean) / self._im_std
+                self._input_pose_arr[:dim, :] = (
+                    pose_arr[cur_ind:end_ind, :] -
+                    self._pose_mean) / self._pose_std
+            elif self._input_depth_mode == InputDepthMode.SUB:
+                self._input_im_arr[:dim, ...] = image_arr[cur_ind:end_ind, ...]
+                self._input_pose_arr[:dim, :] = pose_arr[cur_ind:end_ind, :]
+            elif self._input_depth_mode == InputDepthMode.IM_ONLY:
+                self._input_im_arr[:dim, ...] = (
+                    image_arr[cur_ind:end_ind, ...] -
+                    self._im_mean) / self._im_std
+
+            # Reshape NHWC to NCHW for the Inference Engine; valid here
+            # because the depth image has a single channel.
+            n, c, h, w = self._vino_net.inputs[self._input_im].shape
+            input_im_arr = self._input_im_arr.reshape((n, c, h, w))
+            res = self._vino_exec_net.infer(
+                inputs={
+                    self._input_im: input_im_arr,
+                    self._input_pose: self._input_pose_arr
+                })
+
+            # Copy the batch predictions into the output array.
+            output_arr[cur_ind:end_ind, :] = res[self._output_blob][:dim, :]
+            i = end_ind
+
+        # Get total prediction time.
+        pred_time = time.time() - start_time
+        if verbose:
+            self._logger.info("Prediction took {} seconds.".format(pred_time))
+
+        return output_arr
+
+    def predict(self, image_arr, pose_arr, verbose=False):
+        """Predict the probability of grasp success given a depth image and
+        gripper pose.
+
+        Parameters
+        ----------
+        image_arr : :obj:`numpy ndarray`
+            4D tensor of depth images.
+        pose_arr : :obj:`numpy ndarray`
+            Tensor of gripper poses.
+        verbose : bool
+            Whether or not to log progress to stdout, useful to turn off
+            during training.
+        """
+        return self.predict_openvino(image_arr, pose_arr, verbose=verbose)
diff --git a/gqcnn/model/tf/network_tf.py b/gqcnn/model/tf/network_tf.py
index 8b90d135..d2a86551 100644
--- a/gqcnn/model/tf/network_tf.py
+++ b/gqcnn/model/tf/network_tf.py
@@ -552,6 +552,19 @@ def initialize_network(self,
         if add_sigmoid:
             self.add_sigmoid_to_output()
 
+        # Freeze graph. Set this to `True` to write the frozen graph.
+        if False:
+            from tensorflow.python.framework import graph_io, graph_util
+            self.open_session()
+            frozen = graph_util.convert_variables_to_constants(
+                self._sess, self._sess.graph_def, ["softmax/Softmax"])
+            graph_io.write_graph(frozen,
+                                 ".",
+                                 "inference_graph_frozen.pb",
+                                 as_text=False)
+            self.close_session()
+            self._logger.info("Wrote frozen graph")
+
         # Create feed tensors for prediction.
         self._input_im_arr = np.zeros((self._batch_size, self._im_height,
                                        self._im_width, self._num_channels))