Skip to content
This repository has been archived by the owner on Jan 3, 2023. It is now read-only.

Commit

Permalink
PoC for using MKL-DNN with TensorFlow thread pool
Browse files Browse the repository at this point in the history
The .bazelrc updated:
- TF threading is now default for MKL-DNN
- linking with binary MKL is now disabled by default

Limitations:
- only a single session without inter-op parallelism is supported
- XLA not covered
  • Loading branch information
Roman Dubtsov committed May 31, 2019
1 parent e215e8e commit a9b0a6a
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 21 deletions.
3 changes: 3 additions & 0 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,16 @@ build --define framework_shared_object=true
# environment variable "TF_MKL_ROOT" every time before build.
build:mkl --define=build_with_mkl=true --define=enable_mkl=true
build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl --define build_with_mkl_dnn_only=true
build:mkl --define mkl_dnn_threading=tf
build:mkl -c opt

# This config option is used to enable MKL-DNN open source library only,
# without depending on MKL binary version.
build:mkl_open_source_only --define=build_with_mkl_dnn_only=true
build:mkl_open_source_only --define=build_with_mkl=true --define=enable_mkl=true
build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_open_source_only --define mkl_dnn_threading=tf

build:download_clang --crosstool_top=@local_config_download_clang//:toolchain
build:download_clang --define=using_clang=true
Expand Down
7 changes: 7 additions & 0 deletions tensorflow/core/common_runtime/local_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ limitations under the License.
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/env_var.h"

#if defined(MKLDNN_TF_THREADING)
#include "mkldnn.h"
#endif

namespace tensorflow {
namespace {
class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
Expand Down Expand Up @@ -185,6 +189,9 @@ LocalDevice::LocalDevice(const SessionOptions& options,
tp_info = owned_tp_info_.get();
}
set_tensorflow_cpu_worker_threads(&tp_info->eigen_worker_threads_);
#if defined(MKLDNN_TF_THREADING)
mkldnn_set_tensorflow_thread_pool(tp_info->eigen_worker_threads_.workers);
#endif
set_eigen_cpu_device(tp_info->eigen_device_.get());
}

Expand Down
1 change: 0 additions & 1 deletion tensorflow/core/kernels/mkl_conv_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ limitations under the License.
#include "tensorflow/core/public/session.h"

#if defined(INTEL_MKL_DNN_ONLY)
#include "third_party/intel_mkl_dnn/include/mkldnn.h"
#include "tensorflow/core/util/mkl_util.h"
#endif

Expand Down
2 changes: 2 additions & 0 deletions tensorflow/tensorflow.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ load(
load(
"//third_party/mkl_dnn:build_defs.bzl",
"if_mkl_open_source_only",
"if_mkl_dnn_uses_tf_threading",
)
load(
"//third_party/ngraph:build_defs.bzl",
Expand Down Expand Up @@ -295,6 +296,7 @@ def tf_copts(android_optimization_level_override = "-O2", is_external = False):
if_enable_mkl(["-DENABLE_MKL"]) +
if_ngraph(["-DINTEL_NGRAPH=1"]) +
if_mkl_lnx_x64(["-fopenmp"]) +
if_mkl_dnn_uses_tf_threading(["-fno-openmp -DMKLDNN_TF_THREADING"]) +
if_android_arm(["-mfpu=neon"]) +
if_linux_x86_64(["-msse3"]) +
if_ios_x86_64(["-msse4.1"]) +
Expand Down
8 changes: 4 additions & 4 deletions tensorflow/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "mkl_dnn",
build_file = clean_dep("//third_party/mkl_dnn:mkldnn.BUILD"),
sha256 = "38a1c02104ee9f630c1ad68164119cd58ad0aaf59e04ccbe7bd5781add7bfbea",
strip_prefix = "mkl-dnn-0.18",
sha256 = "ee1a20ccfd9d7f6c2402a9ec9c5057e1d001296d26960e3b7908ade1fc68450b",
strip_prefix = "mkl-dnn-6e6c6c5647f706f43a07a2d39d157b0f0b3f8d44",
urls = [
"http://mirror.tensorflow.org/github.com/intel/mkl-dnn/archive/v0.18.tar.gz",
"https://github.com/intel/mkl-dnn/archive/v0.18.tar.gz",
"https://mirror.bazel.build/github.com/rsdubtso/mkl-dnn/archive/6e6c6c5647f706f43a07a2d39d157b0f0b3f8d44.tar.gz",
"https://github.com/rsdubtso/mkl-dnn/archive/6e6c6c5647f706f43a07a2d39d157b0f0b3f8d44.tar.gz",
],
)

Expand Down
9 changes: 9 additions & 0 deletions third_party/mkl_dnn/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,12 @@ config_setting(
},
visibility = ["//visibility:public"],
)

config_setting(
name = "mkl_dnn_use_tf_threading",
define_values = {
"mkl_dnn_threading" : "tf",
"build_with_mkl_dnn_only": "true",
},
visibility = ["//visibility:public"],
)
11 changes: 10 additions & 1 deletion third_party/mkl_dnn/build_defs.bzl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
def clean_dep(dep):
return str(Label(dep))

def if_mkl_open_source_only(if_true, if_false = []):
"""Shorthand for select()'ing on whether we're building with
MKL-DNN open source lib only, without depending on MKL binary form.
Expand All @@ -8,6 +11,12 @@ def if_mkl_open_source_only(if_true, if_false = []):
"""
return select({
str(Label("//third_party/mkl_dnn:build_with_mkl_dnn_only")): if_true,
clean_dep("//third_party/mkl_dnn:build_with_mkl_dnn_only"): if_true,
"//conditions:default": if_false,
})

def if_mkl_dnn_uses_tf_threading(if_true, if_false = []):
return select({
clean_dep("//third_party/mkl_dnn:mkl_dnn_use_tf_threading"): if_true,
"//conditions:default": if_false,
})
46 changes: 31 additions & 15 deletions third_party/mkl_dnn/mkldnn.BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ exports_files(["LICENSE"])
load(
"@org_tensorflow//third_party/mkl_dnn:build_defs.bzl",
"if_mkl_open_source_only",
"if_mkl_dnn_uses_tf_threading",
)
load(
"@org_tensorflow//third_party:common.bzl",
Expand Down Expand Up @@ -37,6 +38,25 @@ template_rule(
},
)

cc_library(
name = "mkl_dnn_mkl_deps",
deps = select({
"@org_tensorflow//tensorflow:linux_x86_64": [
"@mkl_linux//:mkl_headers",
"@mkl_linux//:mkl_libs_linux",
],
"@org_tensorflow//tensorflow:macos": [
"@mkl_macos//:mkl_headers",
"@mkl_macos//:mkl_libs_macos",
],
"@org_tensorflow//tensorflow:windows": [
"@mkl_windows//:mkl_headers",
"@mkl_windows//:mkl_libs_windows",
],
"//conditions:default": [],
}),
)

cc_library(
name = "mkl_dnn",
srcs = glob([
Expand Down Expand Up @@ -70,7 +90,10 @@ cc_library(
# dependency.
":clang_linux_x86_64": [],
"//conditions:default": [],
}),
}) + if_mkl_dnn_uses_tf_threading([
"-DMKLDNN_THR=MKLDNN_THR_TENSORFLOW",
"-fno-openmp",
]),
includes = [
"include",
"src",
Expand All @@ -81,21 +104,14 @@ cc_library(
],
nocopts = "-fno-exceptions",
visibility = ["//visibility:public"],
deps = select({
"@org_tensorflow//tensorflow:linux_x86_64": [
"@mkl_linux//:mkl_headers",
"@mkl_linux//:mkl_libs_linux",
deps = if_mkl_dnn_uses_tf_threading(
[
"@eigen_archive//:eigen",
"@protobuf_archive//:protobuf_headers",
"@org_tensorflow//tensorflow/core:core_cpu_headers_lib",
"@org_tensorflow//tensorflow/core:framework_headers_lib",
],
"@org_tensorflow//tensorflow:macos": [
"@mkl_darwin//:mkl_headers",
"@mkl_darwin//:mkl_libs_darwin",
],
"@org_tensorflow//tensorflow:windows": [
"@mkl_windows//:mkl_headers",
"@mkl_windows//:mkl_libs_windows",
],
"//conditions:default": [],
}),
["mkl_dnn_mkl_deps"])
)

cc_library(
Expand Down

0 comments on commit a9b0a6a

Please sign in to comment.