diff --git a/tensorflow/core/common_runtime/local_device.cc b/tensorflow/core/common_runtime/local_device.cc index f1fcca194e9ef5..78c4a4643f5228 100644 --- a/tensorflow/core/common_runtime/local_device.cc +++ b/tensorflow/core/common_runtime/local_device.cc @@ -28,6 +28,10 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session_options.h" +#if defined(MKLDNN_TF_THREADING) +#include "mkldnn.h" +#endif + namespace tensorflow { /* static */ @@ -131,6 +135,9 @@ LocalDevice::LocalDevice(const SessionOptions& options, tp_info = owned_tp_info_.get(); } set_tensorflow_cpu_worker_threads(&tp_info->eigen_worker_threads_); +#if defined(MKLDNN_TF_THREADING) + mkldnn_set_tensorflow_thread_pool(tp_info->eigen_worker_threads_.workers); +#endif set_eigen_cpu_device(tp_info->eigen_device_.get()); } diff --git a/tensorflow/core/kernels/mkl_conv_ops_test.cc b/tensorflow/core/kernels/mkl_conv_ops_test.cc index a055351337c7c6..a95f80f0a18180 100644 --- a/tensorflow/core/kernels/mkl_conv_ops_test.cc +++ b/tensorflow/core/kernels/mkl_conv_ops_test.cc @@ -28,7 +28,6 @@ limitations under the License. #include "tensorflow/core/public/session.h" #if defined(INTEL_MKL_DNN_ONLY) -#include "third_party/intel_mkl_dnn/include/mkldnn.h" #include "tensorflow/core/util/mkl_util.h" #endif diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index f10d4a53eb5afc..96912b3bb1d568 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -40,6 +40,7 @@ load( load( "//third_party/mkl_dnn:build_defs.bzl", "if_mkl_open_source_only", + "if_mkl_dnn_uses_tf_threading", ) load( "//third_party/ngraph:build_defs.bzl", @@ -274,6 +275,7 @@ def tf_copts(android_optimization_level_override = "-O2", is_external = False): if_enable_mkl(["-DENABLE_MKL"]) + if_ngraph(["-DINTEL_NGRAPH=1"]) + if_mkl_lnx_x64(["-fopenmp"]) + + if_mkl_dnn_uses_tf_threading(["-fno-openmp -DMKLDNN_TF_THREADING"]) + if_android_arm(["-mfpu=neon"]) + if_linux_x86_64(["-msse3"]) + if_ios_x86_64(["-msse4.1"]) + diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 9ddcc8f8b49206..bd32e91aea4cc6 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -118,11 +118,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""): tf_http_archive( name = "mkl_dnn", build_file = clean_dep("//third_party/mkl_dnn:mkldnn.BUILD"), - sha256 = "6bafc2c794961bd425683adb657b02ffc4296f0db6049c348d7c21b723f6ab15", - strip_prefix = "mkl-dnn-3439371cb7ca17456f0962f288dd17086aed0560", + sha256 = "bd406f4d9d004be270818449d010fe6cecd27af8cd465b5ec5e596905b45bb0d", + strip_prefix = "mkl-dnn-e842faaaf053c4e2926c3f857944c9837366c731", urls = [ - "https://mirror.bazel.build/github.com/rsdubtso/mkl-dnn/archive/3439371cb7ca17456f0962f288dd17086aed0560.tar.gz", - "https://github.com/rsdubtso/mkl-dnn/archive/3439371cb7ca17456f0962f288dd17086aed0560.tar.gz", + "https://mirror.bazel.build/github.com/rsdubtso/mkl-dnn/archive/e842faaaf053c4e2926c3f857944c9837366c731.tar.gz", + "https://github.com/rsdubtso/mkl-dnn/archive/e842faaaf053c4e2926c3f857944c9837366c731.tar.gz", ], ) diff --git a/third_party/mkl_dnn/BUILD b/third_party/mkl_dnn/BUILD index 58ecda55e6eec1..5df94521d9f924 100644 --- a/third_party/mkl_dnn/BUILD +++ b/third_party/mkl_dnn/BUILD @@ -10,3 +10,12 @@ config_setting( }, visibility = ["//visibility:public"], ) + +config_setting( + name = "mkl_dnn_use_tf_threading", + define_values = { + "mkl_dnn_threading" : "tf", + "build_with_mkl_dnn_only": "true", + }, + visibility = ["//visibility:public"], +) diff --git a/third_party/mkl_dnn/build_defs.bzl b/third_party/mkl_dnn/build_defs.bzl index 6388f31971cada..b94dd94c2cf12b 100644 --- a/third_party/mkl_dnn/build_defs.bzl +++ b/third_party/mkl_dnn/build_defs.bzl @@ -1,3 +1,6 @@ +def clean_dep(dep): + return str(Label(dep)) + def if_mkl_open_source_only(if_true, if_false = []): """Shorthand for select()'ing on whether we're building with MKL-DNN open source lib only, without depending on MKL binary form. @@ -8,6 +11,12 @@ def if_mkl_open_source_only(if_true, if_false = []): """ return select({ - str(Label("//third_party/mkl_dnn:build_with_mkl_dnn_only")): if_true, + clean_dep("//third_party/mkl_dnn:build_with_mkl_dnn_only"): if_true, + "//conditions:default": if_false, + }) + +def if_mkl_dnn_uses_tf_threading(if_true, if_false = []): + return select({ + clean_dep("//third_party/mkl_dnn:mkl_dnn_use_tf_threading"): if_true, "//conditions:default": if_false, }) diff --git a/third_party/mkl_dnn/mkldnn.BUILD b/third_party/mkl_dnn/mkldnn.BUILD index 66f41447bebba7..b2bc98e441fde1 100644 --- a/third_party/mkl_dnn/mkldnn.BUILD +++ b/third_party/mkl_dnn/mkldnn.BUILD @@ -3,6 +3,7 @@ exports_files(["LICENSE"]) load( "@org_tensorflow//third_party/mkl_dnn:build_defs.bzl", "if_mkl_open_source_only", + "if_mkl_dnn_uses_tf_threading", ) config_setting( @@ -13,6 +14,25 @@ config_setting( }, ) +cc_library( + name = "mkl_dnn_mkl_deps", + deps = select({ + "@org_tensorflow//tensorflow:linux_x86_64": [ + "@mkl_linux//:mkl_headers", + "@mkl_linux//:mkl_libs_linux", + ], + "@org_tensorflow//tensorflow:darwin": [ + "@mkl_darwin//:mkl_headers", + "@mkl_darwin//:mkl_libs_darwin", + ], + "@org_tensorflow//tensorflow:windows": [ + "@mkl_windows//:mkl_headers", + "@mkl_windows//:mkl_libs_windows", + ], + "//conditions:default": [], + }), +) + cc_library( name = "mkl_dnn", srcs = glob([ @@ -36,7 +56,10 @@ cc_library( # dependency. ":clang_linux_x86_64": [], "//conditions:default": [], - }), + }) + if_mkl_dnn_uses_tf_threading([ + "-DMKLDNN_THR=MKLDNN_THR_TENSORFLOW", + "-fno-openmp", + ]), includes = [ "include", "src", @@ -47,21 +70,14 @@ cc_library( ], nocopts = "-fno-exceptions", visibility = ["//visibility:public"], - deps = select({ - "@org_tensorflow//tensorflow:linux_x86_64": [ - "@mkl_linux//:mkl_headers", - "@mkl_linux//:mkl_libs_linux", - ], - "@org_tensorflow//tensorflow:darwin": [ - "@mkl_darwin//:mkl_headers", - "@mkl_darwin//:mkl_libs_darwin", + deps = if_mkl_dnn_uses_tf_threading( + [ + "@eigen_archive//:eigen", + "@protobuf_archive//:protobuf_headers", + "@org_tensorflow//tensorflow/core:core_cpu_headers_lib", + "@org_tensorflow//tensorflow/core:framework_headers_lib", ], - "@org_tensorflow//tensorflow:windows": [ - "@mkl_windows//:mkl_headers", - "@mkl_windows//:mkl_libs_windows", - ], - "//conditions:default": [], - }), + ["mkl_dnn_mkl_deps"]) ) cc_library(