Skip to content

Commit

Permalink
Update truefoundry to 0.5.1
Browse files Browse the repository at this point in the history
  • Loading branch information
chiragjn committed Dec 12, 2024
1 parent cccdc20 commit 63b69c2
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 12 deletions.
2 changes: 1 addition & 1 deletion base-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ s3fs==2024.9.0
snowflake-connector-python[pandas]==3.12.3
torch==2.3.1+cu121
torchao==0.6.1+cu121
truefoundry==0.5.1rc6
truefoundry==0.5.1
unsloth[cu121-ampere-torch230] @ git+https://github.com/unslothai/unsloth.git@9dc399a6b6625ee40835c5eab361426d3c5d4abb
24 changes: 14 additions & 10 deletions mlfoundry_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import numpy as np
from huggingface_hub import scan_cache_dir
from truefoundry import ml as mlfoundry
from truefoundry import ml

logger = logging.getLogger("axolotl")

Expand Down Expand Up @@ -42,7 +42,7 @@ def download_mlfoundry_artifact(
overwrite: bool = False,
move_to: Optional[str] = None,
):
client = mlfoundry.get_client()
client = ml.get_client()
artifact_version = client.get_artifact_version_by_fqn(artifact_version_fqn)
os.makedirs(download_dir, exist_ok=True)
files_dir = artifact_version.download(download_dir, overwrite=overwrite)
Expand All @@ -52,7 +52,7 @@ def download_mlfoundry_artifact(


def log_model_to_mlfoundry(
run: mlfoundry.MlFoundryRun,
run: ml.MlFoundryRun,
model_name: str,
model_dir: str,
hf_hub_model_id: str,
Expand Down Expand Up @@ -94,7 +94,11 @@ def log_model_to_mlfoundry(
run.log_model(
name=model_name,
model_file_or_folder=model_dir,
framework=mlfoundry.ModelFramework.TRANSFORMERS,
framework=ml.TransformersFramework(
library_name="transformers", # type: ignore
pipeline_tag="text-generation",
base_model=hf_hub_model_id,
),
metadata=metadata,
step=step or 0,
)
Expand All @@ -103,11 +107,11 @@ def log_model_to_mlfoundry(
def get_latest_checkpoint_artifact_version_or_none(
ml_repo: str,
checkpoint_artifact_name: str,
) -> Optional[mlfoundry.ArtifactVersion]:
) -> Optional[ml.ArtifactVersion]:
# TODO (chiragjn): Reduce coupling with checkpointing, log lines are still related
latest_checkpoint_artifact = None
try:
client = mlfoundry.get_client()
client = ml.get_client()
artifact_versions = client.list_artifact_versions(ml_repo=ml_repo, name=checkpoint_artifact_name)
latest_checkpoint_artifact = next(artifact_versions)
except StopIteration:
Expand All @@ -123,10 +127,10 @@ def get_latest_checkpoint_artifact_version_or_none(

def get_checkpoint_artifact_version_with_step_or_none(
ml_repo: str, checkpoint_artifact_name: str, step: int
) -> Optional[mlfoundry.ArtifactVersion]:
) -> Optional[ml.ArtifactVersion]:
checkpoint_artifact_version_with_step = None
try:
client = mlfoundry.get_client()
client = ml.get_client()
artifact_versions = client.list_artifact_versions(ml_repo=ml_repo, name=checkpoint_artifact_name)
for artifact_version in artifact_versions:
if artifact_version.step == step:
Expand Down Expand Up @@ -159,7 +163,7 @@ def generate_run_name(model_id, seed: Optional[int] = None):
def get_or_create_run(ml_repo: str, run_name: str, auto_end: bool = False):
from truefoundry.ml.autogen.client.exceptions import NotFoundException

client = mlfoundry.get_client()
client = ml.get_client()
try:
run = client.get_run_by_name(ml_repo=ml_repo, run_name=run_name)
except Exception as e:
Expand All @@ -169,7 +173,7 @@ def get_or_create_run(ml_repo: str, run_name: str, auto_end: bool = False):
return run


def maybe_log_params_to_mlfoundry(run: mlfoundry.MlFoundryRun, params: Dict[str, Any]):
def maybe_log_params_to_mlfoundry(run: ml.MlFoundryRun, params: Dict[str, Any]):
if not params:
return
if run.get_params():
Expand Down
2 changes: 1 addition & 1 deletion plugins/axolotl_truefoundry/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ description = "TrueFoundry plugin for Axolotl"
requires-python = ">=3.8.1,<4.0"
dependencies = [
"transformers>=4.0.0,<5",
"truefoundry==0.5.1rc6",
"truefoundry>=0.5.1,<0.6.0",
"pynvml>=11.0.0,<12",
"torch>=2.3.0,<2.4.0",
"pydantic>=2.0.0,<3",
Expand Down

0 comments on commit 63b69c2

Please sign in to comment.