Skip to content

Commit

Permalink
make mlflow optional
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Feb 22, 2024
1 parent a359579 commit a195453
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 33 deletions.
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ hf_transfer
colorama
numba
numpy>=1.24.4
mlflow
# qlora things
evaluate==0.4.1
scipy
Expand Down
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,5 +82,8 @@ def parse_requirements():
"auto-gptq": [
"auto-gptq==0.5.1",
],
"mlflow": [
"mlflow",
],
},
)
12 changes: 10 additions & 2 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import abc
import importlib
import importlib.util
import logging
import math
import sys
Expand Down Expand Up @@ -34,7 +35,6 @@
EvalFirstStepCallback,
GPUStatsCallback,
LossWatchDogCallback,
SaveAxolotlConfigtoMlflowCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
Expand Down Expand Up @@ -62,6 +62,10 @@
LOG = logging.getLogger("axolotl.core.trainer_builder")


def is_mlflow_available():
return importlib.util.find_spec("mlflow") is not None


def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]
Expand Down Expand Up @@ -648,7 +652,11 @@ def get_callbacks(self):
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_mlflow:
if self.cfg.use_mlflow and is_mlflow_available():
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)

callbacks.append(
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from typing import TYPE_CHECKING, Dict, List

import evaluate
import mlflow
import numpy as np
import pandas as pd
import torch
Expand Down Expand Up @@ -42,8 +41,8 @@
if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments

LOG = logging.getLogger("axolotl.callbacks")
IGNORE_INDEX = -100
LOG = logging.getLogger("axolotl.callbacks")


class EvalFirstStepCallback(
Expand Down Expand Up @@ -756,31 +755,3 @@ def on_train_begin(
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
return control


class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
"""Callback to save axolotl config to mlflow"""

def __init__(self, axolotl_config_path):
self.axolotl_config_path = axolotl_config_path

def on_train_begin(
self,
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl,
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
try:
with NamedTemporaryFile(
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
copyfile(self.axolotl_config_path, temp_file.name)
mlflow.log_artifact(temp_file.name, artifact_path="")
LOG.info(
"The Axolotl config has been saved to the MLflow artifacts."
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
return control
44 changes: 44 additions & 0 deletions src/axolotl/utils/callbacks/mlflow_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""MLFlow module for trainer callbacks"""
import logging
from shutil import copyfile
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING

import mlflow
from transformers import TrainerCallback, TrainerControl, TrainerState

from axolotl.utils.distributed import is_main_process

if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments

LOG = logging.getLogger("axolotl.callbacks")


class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
# pylint: disable=duplicate-code
"""Callback to save axolotl config to mlflow"""

def __init__(self, axolotl_config_path):
self.axolotl_config_path = axolotl_config_path

def on_train_begin(
self,
args: "AxolotlTrainingArguments", # pylint: disable=unused-argument
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl,
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
try:
with NamedTemporaryFile(
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
copyfile(self.axolotl_config_path, temp_file.name)
mlflow.log_artifact(temp_file.name, artifact_path="")
LOG.info(
"The Axolotl config has been saved to the MLflow artifacts."
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
return control

0 comments on commit a195453

Please sign in to comment.