From 99287222f22a2b0ece40cf19021f28a5e99e4006 Mon Sep 17 00:00:00 2001 From: KulikovNikita Date: Mon, 27 Feb 2023 13:12:25 +0000 Subject: [PATCH] Fixes in LinearRegression SPMD (#1195) --- onedal/datatypes/_data_conversion.py | 15 ++++----------- onedal/primitives/tree_visitor.cpp | 9 +++------ setup.py | 2 +- setup_sklearnex.py | 16 +++++++++++++++- 4 files changed, 23 insertions(+), 19 deletions(-) diff --git a/onedal/datatypes/_data_conversion.py b/onedal/datatypes/_data_conversion.py index c078cbe4db..0aac322585 100644 --- a/onedal/datatypes/_data_conversion.py +++ b/onedal/datatypes/_data_conversion.py @@ -40,25 +40,18 @@ def to_table(*args): if _is_dpc_backend: import numpy as np - from ..common._spmd_policy import _SPMDDataParallelInteropPolicy - from ..common._policy import _HostInteropPolicy, _DataParallelInteropPolicy + from ..common._policy import _HostInteropPolicy def _convert_to_supported_impl(policy, *data): # CPUs support FP64 by default - is_host = isinstance(policy, _HostInteropPolicy) - no_dpcpp = not _is_dpc_backend - if is_host or no_dpcpp: + if isinstance(policy, _HostInteropPolicy): return data - # There is only one option of data parallel policy - is_dpcpp_policy = isinstance(policy, _DataParallelInteropPolicy) - is_spmd_policy = isinstance(policy, _SPMDDataParallelInteropPolicy) - assert is_spmd_policy or is_dpcpp_policy - + # It can be either SPMD or DPCPP policy device = policy._queue.sycl_device def convert_or_pass(x): - if x.dtype is not np.float32: + if x.dtype is np.float64: return x.astype(np.float32) else: return x diff --git a/onedal/primitives/tree_visitor.cpp b/onedal/primitives/tree_visitor.cpp index 1270a151f1..12e5311e33 100644 --- a/onedal/primitives/tree_visitor.cpp +++ b/onedal/primitives/tree_visitor.cpp @@ -26,9 +26,6 @@ #include #include -#include -#include - #define ONEDAL_PY_TERMINAL_NODE -1 #define ONEDAL_PY_NO_FEATURE -2 @@ -45,7 +42,7 @@ inline static const double get_nan64() { // equivalent for numpy arange template -std::vector arange(T start, T stop, T step = 1) { +inline std::vector arange(T start, T stop, T step = 1) { std::vector res; for (T i = start; i < stop; i += step) res.push_back(i); @@ -128,7 +125,7 @@ class node_visitor { template class to_sklearn_tree_object_visitor : public tree_state { public: - to_sklearn_tree_object_visitor(size_t _depth, + to_sklearn_tree_object_visitor(std::size_t _depth, std::size_t _n_nodes, std::size_t _n_leafs, std::size_t _max_n_classes); @@ -143,7 +140,7 @@ class to_sklearn_tree_object_visitor : public tree_state { }; template -to_sklearn_tree_object_visitor::to_sklearn_tree_object_visitor(size_t _depth, +to_sklearn_tree_object_visitor::to_sklearn_tree_object_visitor(std::size_t _depth, std::size_t _n_nodes, std::size_t _n_leafs, std::size_t _max_n_classes) diff --git a/setup.py b/setup.py index e3cb6f9e62..3e5ed8840b 100644 --- a/setup.py +++ b/setup.py @@ -82,7 +82,7 @@ except ImportError: dpctl_available = False -build_distribute = dpcpp and dpctl_available and not no_dist +build_distribute = dpcpp and dpctl_available and not no_dist and IS_LIN daal_lib_dir = lib_dir if (IS_MAC or os.path.isdir( diff --git a/setup_sklearnex.py b/setup_sklearnex.py index 82f0f6fd0d..f3278ce099 100755 --- a/setup_sklearnex.py +++ b/setup_sklearnex.py @@ -17,6 +17,7 @@ # System imports import os +import sys import time from setuptools import setup from scripts.version import get_onedal_version @@ -25,6 +26,19 @@ sklearnex_version = (os.environ["SKLEARNEX_VERSION"] if "SKLEARNEX_VERSION" in os.environ else time.strftime("%Y%m%d.%H%M%S")) +IS_WIN = False +IS_MAC = False +IS_LIN = False + +if 'linux' in sys.platform: + IS_LIN = True +elif sys.platform == 'darwin': + IS_MAC = True +elif sys.platform in ['win32', 'cygwin']: + IS_WIN = True +else: + assert False, sys.platform + ' not supported' + dal_root = os.environ.get('DALROOT') if dal_root is None: @@ -41,7 +55,7 @@ except ImportError: dpctl_available = False -build_distribute = dpcpp and dpctl_available and not no_dist +build_distribute = dpcpp and dpctl_available and not no_dist and IS_LIN ONEDAL_VERSION = get_onedal_version(dal_root)