-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmlflow_helper.py
135 lines (108 loc) · 4.75 KB
/
mlflow_helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import os
import sys
import io
import mlflow
import subprocess
from nbconvert import HTMLExporter
import nbformat
class MLFlowLogger:
"""
A logger class for MLFlow to manage experiments and log parameters, metrics, and artifacts.
Attributes:
experiment_name (str): The name of the MLFlow experiment.
run_name (str|None): The name of the current MLFlow run.
"""
def __init__(self, experiment_name: str = "Default", run_name: str | None = None):
"""
Initializes the MLFlowLogger with specified experiment and run names.
Sets up MLFlow tracking and S3 endpoint environment variables, and ensures AWS credentials are configured.
Args:
experiment_name (str, optional): Name of the experiment. Defaults to "Default".
run_name (str | None, optional): Name of the run. Defaults to None.
"""
self.MLFLOW_SERVER_URL = "MLFLOW_SERVER_URL"
# mlflow credentials
self._user = "MLFLOW_USER"
self._password= "MLFLOW_PASSWORD"
# s3 endpoint for artifacts
self._s3_endpoint = "URL_S3_ENDPOINT"
self._aws_access_key_id = "minio"
self._aws_secret_access_key = "_aws_secret_access_key"
self._bucket_name = "_bucket_name"
self.set_env_variables()
self.experiment_name = experiment_name
self.run_name = run_name
def set_env_variables(self):
"""
Sets the necessary environment variables for MLFlow and AWS S3 integration.
Checks if AWS credentials file exists and creates it if not present.
"""
os.environ["MLFLOW_TRACKING_USERNAME"] = self._user
os.environ["MLFLOW_TRACKING_PASSWORD"] = self._password
os.environ["MLFLOW_S3_ENDPOINT_URL"] = self._s3_endpoint
os.environ["AWS_BUCKET_NAME"] = self._bucket_name
# check if credentials for artifacts are set
os.makedirs(os.path.expanduser("~/.aws"), exist_ok=True)
if not os.path.isfile(os.path.expanduser("~/.aws/credentials")):
content = f"""
[default]
aws_access_key_id={self._aws_access_key_id}
aws_secret_access_key={self._aws_secret_access_key}
"""
with open(os.path.expanduser("~/.aws/credentials"), "w") as f:
f.write(content)
@staticmethod
def log_commit_hash(logger = None, run_id: str | None = None):
"""
Logs the current Git commit hash to MLFlow.
Args:
logger (MLFlowLogger, optional): An instance of MLFlowLogger. Defaults to None.
run_id (str | None, optional): The run ID for logging. Defaults to None.
"""
# Get the current git commit hash
commit_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip()
# Log the commit hash
if logger is None:
mlflow.log_param("git_commit_hash", commit_hash)
else:
logger.experiment.log_param(run_id=run_id, key="git_commit_hash", value=commit_hash)
@staticmethod
def log_notebook(filename: str, artifact_name: str = "notebook.html", logger = None, run_id: str | None = None):
"""
Converts a Jupyter Notebook to HTML and logs it as an artifact in MLFlow.
Args:
filename (str): The filename of the Jupyter Notebook.
artifact_name (str, optional): The name for the logged artifact. Defaults to "notebook.html".
logger (MLFlowLogger, optional): An instance of MLFlowLogger. Defaults to None.
run_id (str | None, optional): The run ID for logging. Defaults to None.
"""
# Load the current notebook
current_notebook = nbformat.read(open(filename), as_version=4)
# Export the notebook to HTML
exporter = HTMLExporter()
body, _ = exporter.from_notebook_node(current_notebook)
# Save the HTML to a file
with open(artifact_name, "w") as file:
file.write(body)
if logger is None:
mlflow.log_artifact(artifact_name, artifact_path="notebook")
else:
logger.experiment.log_artifact(run_id=run_id, local_path=artifact_name, artifact_path="notebook")
os.remove(artifact_name)
def start_run(self):
"""
Starts an MLFlow run, setting up the tracking URI and experiment.
"""
mlflow.set_tracking_uri(self.MLFLOW_SERVER_URL)
mlflow.set_experiment(self.experiment_name)
mlflow.start_run()
def end_run(self):
"""
Ends the current MLFlow run.
"""
mlflow.end_run()
if __name__ == "__main__":
    # Smoke run: open a run in the default experiment, record the commit the
    # code was executed from, then close the run again.
    tracker = MLFlowLogger()
    tracker.start_run()
    # log_commit_hash is a staticmethod; with no arguments it logs to the
    # active run through the mlflow module.
    MLFlowLogger.log_commit_hash()
    tracker.end_run()