Skip to content
This repository was archived by the owner on Jul 16, 2021. It is now read-only.

Provide informative error message on bad type input to predict #35

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

mrocklin
Copy link
Member

No description provided.

@mrocklin
Copy link
Member Author

mrocklin commented Feb 21, 2019 via email

@TomAugspurger
Copy link
Member

I suspect the test_sparse one is similar to what we ran into with dask. IIRC sparse changed to be stricter about not converting to dense.

No idea about the other ones unfortunately :/ Possibly something with pytest-xdist?

FWIW, I have a local (unpunished) branch called test-fixup with this diff

diff --git a/.circleci/config.yml b/.circleci/config.yml
index f1463079..72faf516 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -16,7 +16,7 @@ jobs:
             conda config --add channels conda-forge
             conda create -q -n test-environment python=${PYTHON}
             source activate test-environment
-            conda install -q coverage flake8 pytest pytest-cov pytest-xdist numpy pandas xgboost dask distributed scikit-learn sparse scipy
+            conda install -q coverage flake8 pytest pytest-cov numpy pandas xgboost dask distributed scikit-learn sparse scipy
             pip install -e .
             conda list test-environment
       - run:
diff --git a/dask_xgboost/core.py b/dask_xgboost/core.py
index 6bf29d78..c843a000 100644
--- a/dask_xgboost/core.py
+++ b/dask_xgboost/core.py
@@ -34,7 +34,7 @@ def parse_host_port(address):
     return host, port
 
 
-def start_tracker(host, n_workers):
+def start_tracker(host, n_workers, dask_scheduler=None):
     """ Start Rabit tracker """
     env = {'DMLC_NUM_WORKER': n_workers}
     rabit = RabitTracker(hostIP=host, nslave=n_workers)
@@ -45,6 +45,7 @@ def start_tracker(host, n_workers):
     thread = Thread(target=rabit.join)
     thread.daemon = True
     thread.start()
+    dask_scheduler.xgboost_thread = thread
     return env
 
 
@@ -155,6 +156,13 @@ def _train(client, params, data, labels, dmatrix_kwargs={}, **kwargs):
     num_class = params.get("num_class")
     if num_class:
         result.set_attr(num_class=str(num_class))
+
+    def wait_on_tracker_thread(dask_scheduler):
+        dask_scheduler.xgboost_thread.join()
+        del dask_scheduler.xgboost_thread
+
+    yield client.run_on_scheduler(wait_on_tracker_thread)
+
     raise gen.Return(result)
 
 
diff --git a/setup.cfg b/setup.cfg
index 2348f495..11894603 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -5,4 +5,4 @@ universal=1
 exclude = tests/data,docs,benchmarks,scripts
 
 [tool:pytest]
-addopts = -rsx -v -n 1 --boxed
+addopts = -rsx -v

Looking further, that looks like #29 (comment)

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants