Merge pull request #2430 from cta-observatory/n_jobs
Allow setting n_jobs on tool invocation
maxnoe authored Nov 23, 2023
2 parents 0eb4245 + 0372d11 commit 4a45dcd
Showing 11 changed files with 133 additions and 16 deletions.
10 changes: 8 additions & 2 deletions ctapipe/reco/reconstructor.py
@@ -8,7 +8,7 @@

from ctapipe.containers import ArrayEventContainer, TelescopeImpactParameterContainer
from ctapipe.core import Provenance, QualityQuery, TelescopeComponent
from ctapipe.core.traits import List
from ctapipe.core.traits import Integer, List

from ..compat import StrEnum
from ..coordinates import shower_impact_distance
@@ -70,9 +70,15 @@ class Reconstructor(TelescopeComponent):
algorithms should inherit from
"""

#: ctapipe_rco entry points may provide Reconstructor implementations
#: ctapipe_reco entry points may provide Reconstructor implementations
plugin_entry_point = "ctapipe_reco"

n_jobs = Integer(
default_value=None,
allow_none=True,
help="Number of threads to use for the reconstruction if supported by the reconstructor.",
).tag(config=True)

def __init__(self, subarray, **kwargs):
super().__init__(subarray=subarray, **kwargs)
self.quality_query = StereoQualityQuery(parent=self)
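The new n_jobs trait on the Reconstructor base class is a regular configurable trait, so besides the command-line flag it can also be set through the traitlets configuration system. A minimal sketch, mirroring the test added further down (subarray stands for an existing SubarrayDescription; all values are illustrative):

from traitlets.config import Config

from ctapipe.reco import EnergyRegressor

config = Config(
    {
        "EnergyRegressor": {
            "model_cls": "RandomForestRegressor",
            "model_config": {"n_estimators": 20},
            # picked up by the new Reconstructor.n_jobs trait,
            # overriding any n_jobs inside model_config
            "n_jobs": 4,
        }
    }
)
regressor = EnergyRegressor(subarray, config=config)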
38 changes: 32 additions & 6 deletions ctapipe/reco/sklearn.py
@@ -16,7 +16,7 @@
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.utils import all_estimators
from tqdm import tqdm
from traitlets import TraitError
from traitlets import TraitError, observe

from ctapipe.exceptions import TooFewEvents

@@ -103,7 +103,7 @@ class SKLearnReconstructor(Reconstructor):
help="If given, load serialized model from this path",
).tag(config=True)

def __init__(self, subarray=None, models=None, **kwargs):
def __init__(self, subarray=None, models=None, n_jobs=None, **kwargs):
# Run the Component __init__ first to handle the configuration
# and make `self.load_path` available
Component.__init__(self, **kwargs)
@@ -199,7 +199,10 @@ def instrument_table(self):
return QTable(self.subarray.to_table("joined"))

def _new_model(self):
return SUPPORTED_MODELS[self.model_cls](**self.model_config)
cfg = self.model_config
if self.n_jobs:
cfg["n_jobs"] = self.n_jobs
return SUPPORTED_MODELS[self.model_cls](**cfg)

def _table_to_y(self, table, mask=None):
"""
@@ -222,6 +225,15 @@ def fit(self, key, table):
y = self._table_to_y(table, mask=valid)
self._models[key].fit(X, y)

@observe("n_jobs")
def _set_n_jobs(self, n_jobs):
"""
Update n_jobs of all associated models.
"""
if hasattr(self, "_models"):
for model in self._models.values():
model.n_jobs = n_jobs.new


class SKLearnRegressionReconstructor(SKLearnReconstructor):
"""
@@ -562,7 +574,6 @@ def __init__(self, subarray=None, models=None, **kwargs):

# to verify settings
self._new_models()

self._models = {} if models is None else models
self.unit = None
self.stereo_combiner = StereoCombiner.from_name(
@@ -584,8 +595,13 @@
self.subarray = subarray

def _new_models(self):
norm_regressor = SUPPORTED_REGRESSORS[self.norm_cls](**self.norm_config)
sign_classifier = SUPPORTED_CLASSIFIERS[self.sign_cls](**self.sign_config)
norm_cfg = self.norm_config
sign_cfg = self.sign_config
if self.n_jobs:
norm_cfg["n_jobs"] = self.n_jobs
sign_cfg["n_jobs"] = self.n_jobs
norm_regressor = SUPPORTED_REGRESSORS[self.norm_cls](**norm_cfg)
sign_classifier = SUPPORTED_CLASSIFIERS[self.sign_cls](**sign_cfg)
return norm_regressor, sign_classifier

def _table_to_y(self, table, mask=None):
@@ -803,6 +819,16 @@ def predict_table(self, key, table: Table) -> Dict[ReconstructionProperty, Table
ReconstructionProperty.GEOMETRY: altaz_result,
}

@observe("n_jobs")
def _set_n_jobs(self, n_jobs):
"""
Update n_jobs of all associated models.
"""
if hasattr(self, "_models"):
for (disp, sign) in self._models.values():
disp.n_jobs = n_jobs.new
sign.n_jobs = n_jobs.new


class CrossValidator(Component):
"""Class to train sklearn based reconstructors in a cross validation"""
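The two _set_n_jobs handlers above use the traitlets observer API: a method decorated with @observe("n_jobs") is called whenever the trait is assigned and receives a change object whose .new attribute carries the new value, which is why the handlers read n_jobs.new. A stripped-down sketch of that mechanism, independent of ctapipe (class and attribute names here are illustrative only):

from traitlets import HasTraits, Integer, observe


class Worker(HasTraits):
    n_jobs = Integer(default_value=None, allow_none=True)

    @observe("n_jobs")
    def _on_n_jobs_change(self, change):
        # change behaves like a small record with .name, .old and .new
        print(f"n_jobs changed from {change.old} to {change.new}")


w = Worker()
w.n_jobs = 4  # prints: n_jobs changed from None to 4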
44 changes: 44 additions & 0 deletions ctapipe/reco/tests/test_sklearn.py
@@ -9,6 +9,7 @@
from ctapipe.core import Component
from ctapipe.reco import EnergyRegressor, ParticleClassifier
from ctapipe.reco.reconstructor import ReconstructionProperty
from ctapipe.reco.sklearn import DispReconstructor

KEY = "LST_LST_LSTCam"

@@ -166,6 +167,49 @@ def test_regressor_single_event(model_cls, example_table, example_subarray):
assert valid[0] == False


def test_set_n_jobs(example_subarray):
config = Config(
{
"EnergyRegressor": {
"model_cls": "RandomForestRegressor",
"model_config": {"n_estimators": 20, "max_depth": 15, "n_jobs": -1},
}
}
)
regressor = EnergyRegressor(
example_subarray,
config=config,
)

regressor._models["telescope"] = regressor._new_model()
assert regressor._models["telescope"].n_jobs == -1
regressor.n_jobs = 42
assert regressor._models["telescope"].n_jobs == 42

# DISP has two models per telescope, check that as well
config = Config(
{
"DispReconstructor": {
"norm_cls": "RandomForestRegressor",
"norm_config": {"n_estimators": 20, "max_depth": 15, "n_jobs": -1},
"sign_cls": "RandomForestClassifier",
"sign_config": {"n_estimators": 20, "max_depth": 15, "n_jobs": -1},
}
}
)
disp = DispReconstructor(
example_subarray,
config=config,
)

disp._models["telescope"] = disp._new_models()
assert disp._models["telescope"][0].n_jobs == -1
assert disp._models["telescope"][1].n_jobs == -1
disp.n_jobs = 42
assert disp._models["telescope"][0].n_jobs == 42
assert disp._models["telescope"][1].n_jobs == 42


@pytest.mark.parametrize(
"model_cls", ["KNeighborsClassifier", "RandomForestClassifier"]
)
5 changes: 3 additions & 2 deletions ctapipe/resources/train_disp_reconstructor.yaml
@@ -19,13 +19,15 @@ TrainDispReconstructor:
# prefix: # Add a prefix of the output here, if you want to apply multiple
# DispReconstructors on the same file (e.g. for comparing different settings)

# How many cores to use. Overwrites model config
n_jobs: -1

# All regression algorithms in scikit-learn are supported
# (https://scikit-learn.org/stable/modules/classes.html)
norm_cls: ExtraTreesRegressor
norm_config:
n_estimators: 10
max_depth: 10
n_jobs: -1

log_target: True

@@ -35,7 +37,6 @@ TrainDispReconstructor:
sign_config:
n_estimators: 10
max_depth: 10
n_jobs: -1

QualityQuery: # Event Selection performed before training the models
quality_criteria:
2 changes: 1 addition & 1 deletion ctapipe/resources/train_energy_regressor.yaml
@@ -20,14 +20,14 @@ TrainEnergyRegressor:
# EnergyRegressors on the same file (e.g. for comparing different settings)

log_target: True
n_jobs: -1

# All regression algorithms in scikit-learn are supported
# (https://scikit-learn.org/stable/modules/classes.html)
model_cls: ExtraTreesRegressor
model_config:
n_estimators: 10
max_depth: 10
n_jobs: -1

QualityQuery: # Event Selection performed before training the models
quality_criteria:
4 changes: 3 additions & 1 deletion ctapipe/resources/train_particle_classifier.yaml
@@ -22,13 +22,15 @@ TrainParticleClassifier:
# prefix: # Add a prefix of the output here, if you want to apply multiple
# ParticleClassifiers on the same file (e.g. for comparing different settings)

# How many cores to use. Overwrites model config
n_jobs: -1

# All classification algorithms in scikit-learn are supported
# (https://scikit-learn.org/stable/modules/classes.html)
model_cls: ExtraTreesClassifier
model_config:
n_estimators: 10
max_depth: 10
n_jobs: -1

QualityQuery: # Event Selection performed before training the models
quality_criteria:
17 changes: 13 additions & 4 deletions ctapipe/tools/apply_models.py
@@ -68,6 +68,12 @@ class ApplyModels(Tool):
help="How many subarray events to load at once for making predictions.",
).tag(config=True)

n_jobs = Integer(
default_value=None,
allow_none=True,
help="Number of threads to use for the reconstruction. This overwrites the values in the config",
).tag(config=True)

progress_bar = Bool(
help="show progress bar during processing",
default_value=True,
@@ -77,6 +83,7 @@
("i", "input"): "ApplyModels.input_url",
("r", "reconstructor"): "ApplyModels.reconstructor_paths",
("o", "output"): "ApplyModels.output_path",
"n-jobs": "ApplyModels.n_jobs",
"chunk-size": "ApplyModels.chunk_size",
}

@@ -146,10 +153,12 @@ def setup(self):
)
)

self._reconstructors = [
Reconstructor.read(path, parent=self, subarray=self.loader.subarray)
for path in self.reconstructor_paths
]
self._reconstructors = []
for path in self.reconstructor_paths:
r = Reconstructor.read(path, parent=self, subarray=self.loader.subarray)
if self.n_jobs:
r.n_jobs = self.n_jobs
self._reconstructors.append(r)

def start(self):
"""Apply models to input tables"""
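Reading a serialized reconstructor and then assigning the trait is all the override amounts to; the assignment triggers the observers defined in sklearn.py. A rough stand-alone equivalent of what the loop above does for each reconstructor path (the file name is illustrative, and subarray is assumed to be the SubarrayDescription of the input file):

from ctapipe.reco.reconstructor import Reconstructor

reconstructor = Reconstructor.read("energy_regressor.pkl", subarray=subarray)
reconstructor.n_jobs = 8  # overrides the value stored with the trained model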
9 changes: 9 additions & 0 deletions ctapipe/tools/train_disp_reconstructor.py
@@ -66,6 +66,12 @@ class TrainDispReconstructor(Tool):
default_value=0, help="Random seed for sampling and cross validation"
).tag(config=True)

n_jobs = Int(
default_value=None,
allow_none=True,
help="Number of threads to use for the reconstruction. This overwrites the values in the config of each reconstructor.",
).tag(config=True)

project_disp = Bool(
default_value=False,
help=(
@@ -80,6 +86,7 @@ class TrainDispReconstructor(Tool):
("i", "input"): "TableLoader.input_url",
("o", "output"): "TrainDispReconstructor.output_path",
"n-events": "TrainDispReconstructor.n_events",
"n-jobs": "DispReconstructor.n_jobs",
"cv-output": "CrossValidator.output_path",
}

@@ -103,6 +110,7 @@ def setup(self):
self.n_events.attach_subarray(self.loader.subarray)

self.models = DispReconstructor(self.loader.subarray, parent=self)

self.cross_validate = CrossValidator(parent=self, model_component=self.models)
self.rng = np.random.default_rng(self.random_seed)
self.check_output(self.output_path, self.cross_validate.output_path)
@@ -182,6 +190,7 @@ def finish(self):
Write-out trained models and cross-validation results.
"""
self.log.info("Writing output")
self.models.n_jobs = None
self.models.write(self.output_path, overwrite=self.overwrite)
if self.cross_validate.output_path:
self.cross_validate.write(overwrite=self.overwrite)
9 changes: 9 additions & 0 deletions ctapipe/tools/train_energy_regressor.py
@@ -63,11 +63,18 @@ class TrainEnergyRegressor(Tool):
default_value=0, help="Random seed for sampling and cross validation"
).tag(config=True)

n_jobs = Int(
default_value=None,
allow_none=True,
help="Number of threads to use for the reconstruction. This overwrites the values in the config of each reconstructor.",
).tag(config=True)

aliases = {
("i", "input"): "TableLoader.input_url",
("o", "output"): "TrainEnergyRegressor.output_path",
"n-events": "TrainEnergyRegressor.n_events",
"chunk-size": "TrainEnergyRegressor.chunk_size",
"n-jobs": "EnergyRegressor.n_jobs",
"cv-output": "CrossValidator.output_path",
}

@@ -94,6 +101,7 @@ def setup(self):
self.n_events.attach_subarray(self.loader.subarray)

self.regressor = EnergyRegressor(self.loader.subarray, parent=self)
self.log.warning(f"{self.regressor._models}")
self.cross_validate = CrossValidator(
parent=self, model_component=self.regressor
)
@@ -137,6 +145,7 @@ def finish(self):
Write-out trained models and cross-validation results.
"""
self.log.info("Writing output")
self.regressor.n_jobs = None
self.regressor.write(self.output_path, overwrite=self.overwrite)
if self.cross_validate.output_path:
self.cross_validate.write(overwrite=self.overwrite)
8 changes: 8 additions & 0 deletions ctapipe/tools/train_particle_classifier.py
@@ -92,11 +92,18 @@ class TrainParticleClassifier(Tool):
help="Random number seed for sampling and the cross validation splitting",
).tag(config=True)

n_jobs = Int(
default_value=None,
allow_none=True,
help="Number of threads to use for the reconstruction. This overwrites the values in the config of each reconstructor.",
).tag(config=True)

aliases = {
"signal": "TrainParticleClassifier.input_url_signal",
"background": "TrainParticleClassifier.input_url_background",
"n-signal": "TrainParticleClassifier.n_signal",
"n-background": "TrainParticleClassifier.n_background",
"n-jobs": "ParticleClassifier.n_jobs",
("o", "output"): "TrainParticleClassifier.output_path",
"cv-output": "CrossValidator.output_path",
}
@@ -207,6 +214,7 @@ def finish(self):
Write-out trained models and cross-validation results.
"""
self.log.info("Writing output")
self.classifier.n_jobs = None
self.classifier.write(self.output_path, overwrite=self.overwrite)
self.signal_loader.close()
self.background_loader.close()
3 changes: 3 additions & 0 deletions docs/changes/2430.feature.rst
@@ -0,0 +1,3 @@
Allow setting n_jobs on the command line for the
train_* and apply_models tools using a new ``n_jobs`` flag.
This temporarily overwrites any settings in the (model) config(s).
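An illustrative invocation (console-script and file names are assumptions, not taken from this diff): ctapipe-train-energy-regressor --input gamma_train.dl1.h5 --output energy_regressor.pkl --n-jobs 8 trains with 8 threads, while the finish() methods above reset n_jobs to None before the model is written, so the command-line override is not persisted into the stored model.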
