Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add HMMClassifier.fit multiprocessing #259

Merged
merged 1 commit into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/sections/models/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ The following models provided by Sequentia all support variable length sequences
| | | | +----------+------------+
| | | | | Training | Prediction |
+=========================+==============================+================+===============+==============+==========+============+
| :class:`.HMMClassifier` | :class:`.GaussianMixtureHMM` | Classification | Real | ✔ | | ✔ |
| :class:`.HMMClassifier` | :class:`.GaussianMixtureHMM` | Classification | Real | ✔ | | ✔ |
| +------------------------------+----------------+---------------+--------------+----------+------------+
| | :class:`.CategoricalHMM` | Classification | Categorical | ✗ | | ✔ |
| | :class:`.CategoricalHMM` | Classification | Categorical | ✗ | | ✔ |
+-------------------------+------------------------------+----------------+---------------+--------------+----------+------------+
| :class:`.KNNRegressor` | Regression | Real | ✔ | N/A | ✔ |
+--------------------------------------------------------+----------------+---------------+--------------+----------+------------+
Expand Down
5 changes: 3 additions & 2 deletions sequentia/model_selection/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,10 @@ def param_grid(**kwargs: list[t.Any]) -> list[dict[str, t.Any]]:
settings for :class:`.GaussianMixtureHMM`, which is a nested model
specified in the constructor of a :class:`.HMMClassifier`. ::

from sklearn.preprocessing import Pipeline, minmax_scale
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import minmax_scale

from sequenta.enums import PriorMode, CovarianceMode, TopologyMode
from sequentia.enums import PriorMode, CovarianceMode, TopologyMode
from sequentia.models import HMMClassifier, GaussianMixtureHMM
from sequentia.preprocessing import IndependentFunctionTransformer
from sequentia.model_selection import GridSearchCV, StratifiedKFold
Expand Down
20 changes: 18 additions & 2 deletions sequentia/models/hmm/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,8 +356,24 @@ def fit(
lengths=lengths,
classes=self.classes_,
)
for X_c, lengths_c, c in dataset.iter_by_class():
self.models[c].fit(X_c, lengths=lengths_c)

# get number of jobs
n_jobs = _multiprocessing.effective_n_jobs(
self.n_jobs, x=self.classes_
)

# fit models in parallel
self.models = dict(
zip(
self.classes_,
joblib.Parallel(n_jobs=n_jobs, max_nbytes=None)(
joblib.delayed(self.models[c].fit)(
X_c, lengths=lengths_c
)
for X_c, lengths_c, c in dataset.iter_by_class()
),
)
)

# Set class priors
models: t.Iterable[int, variants.BaseHMM] = self.models.items()
Expand Down
5 changes: 4 additions & 1 deletion tests/unit/test_models/hmm/test_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,16 +121,18 @@ def assert_fit(clf: BaseHMM):
],
)
@pytest.mark.parametrize("fit_mode", list(FitMode))
@pytest.mark.parametrize("n_jobs", [1, -1])
def test_classifier_e2e(
request: SubRequest,
helpers: t.Any,
model: BaseHMM,
dataset: SequentialDataset,
prior: enums.PriorMode | dict[int, float],
fit_mode: FitMode,
n_jobs: int,
random_state: np.random.RandomState,
) -> None:
clf = HMMClassifier(prior=prior)
clf = HMMClassifier(prior=prior, n_jobs=n_jobs)
clf.add_models({i: copy.deepcopy(model) for i in range(n_classes)})

assert clf.prior == prior
Expand All @@ -156,6 +158,7 @@ def test_classifier_e2e(
variant=type(model),
model_kwargs=model.get_params(),
prior=prior,
n_jobs=n_jobs,
)
clf.fit(**train.X_y_lengths)

Expand Down
Loading