-
Notifications
You must be signed in to change notification settings - Fork 1
/
model_trainer.py
89 lines (73 loc) · 3.19 KB
/
model_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# {% include 'template/license_header' %}
from typing_extensions import Annotated
import mlflow
import pandas as pd
from sklearn.base import ClassifierMixin
from zenml import ArtifactConfig, step, get_step_context
from zenml.client import Client
from zenml.integrations.mlflow.experiment_trackers import MLFlowExperimentTracker
from zenml.integrations.mlflow.steps.mlflow_registry import mlflow_register_model_step
from zenml.logger import get_logger
logger = get_logger(__name__)

# Resolve the experiment tracker of the active ZenML stack once, at import
# time, so later definitions in this module can reference it. Importing this
# module fails fast unless that tracker exists and is an MLflow tracker.
experiment_tracker = Client().active_stack.experiment_tracker
if not (
    experiment_tracker
    and isinstance(experiment_tracker, MLFlowExperimentTracker)
):
    raise RuntimeError(
        "Your active stack needs to contain a MLFlow experiment tracker for "
        "this example to work."
    )
@step(experiment_tracker=experiment_tracker.name)
def model_trainer(
    dataset_trn: pd.DataFrame,
    model: ClassifierMixin,
    target: str,
    name: str,
) -> Annotated[ClassifierMixin, ArtifactConfig(name="model", is_model_artifact=True)]:
    """Fit a classifier on the training dataset and register the result.

    Enables MLflow autologging for scikit-learn, then fits the supplied
    ``model`` on every column of ``dataset_trn`` except ``target``, using
    ``target`` as the labels. The fitted model is registered under ``name``
    through ZenML's ``mlflow_register_model_step``. If the active stack has a
    model registry, the step then looks up the latest registered version for
    ``name``. When that version exists, its number is attached as
    ``model_registry_version`` metadata on the ZenML model of the current step
    context.

    The step does not construct the estimator or set any random seed itself;
    ``model`` arrives already built. Whether training is deterministic, and so
    whether this step is safe to cache, depends on how ``model`` was configured
    upstream, for example with a fixed ``random_state``. Disable caching if it
    is not deterministic.

    Step parameters can be configured independently of the step code, before
    running it in a pipeline. See:
    https://docs.zenml.io/how-to/build-pipelines/use-pipeline-step-parameters

    Args:
        dataset_trn: The preprocessed training dataset, including the target
            column.
        model: The classifier instance to train; ``fit`` is called on it
            directly and the same object is returned.
        target: Name of the target column in ``dataset_trn``.
        name: Name under which the fitted model is registered.

    Returns:
        The same ``model`` object after fitting.
    """
    ### ADD YOUR OWN CODE HERE - THIS IS JUST AN EXAMPLE ###
    logger.info(f"Training model {model}...")

    # Turn on MLflow's scikit-learn autologging before the fit happens.
    mlflow.sklearn.autolog()

    # Separate features from labels, then fit the estimator in place.
    features = dataset_trn.drop(columns=[target])
    labels = dataset_trn[target]
    model.fit(features, labels)

    # Register the fitted model by reusing ZenML's MLflow registration step.
    mlflow_register_model_step.entrypoint(model, name=name)

    # Remember which registry version was produced, for use by later steps.
    registry = Client().active_stack.model_registry
    latest = (
        registry.get_latest_model_version(name=name, stage=None)
        if registry
        else None
    )
    if latest:
        get_step_context().model.log_metadata(
            {"model_registry_version": latest.version}
        )
    ### YOUR CODE ENDS HERE ###
    return model