Add control to skip logging final model
chiragjn committed Nov 19, 2024
1 parent f810b26 commit 53f6f63
Showing 3 changed files with 10 additions and 3 deletions.
4 changes: 3 additions & 1 deletion config-base.yaml
@@ -116,13 +116,15 @@ cleanup_output_dir_on_start: False
dataset_type: chat # Can be completion | chat
drop_long_sequences: False
logging_dir: ./tensorboard_logs
truefoundry_ml_run_name: auto # type: string
save_model_on_interrupt: False
train_data_uri: null
truefoundry_ml_checkpoint_artifact_name: auto # type: string
truefoundry_ml_enable_reporting: False
truefoundry_ml_log_checkpoints: True
truefoundry_ml_log_gpu_metrics: False
truefoundry_ml_log_merged_model: True
truefoundry_ml_repo: null
truefoundry_ml_run_name: auto # type: string
val_data_uri: null

## Liger
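The two new keys in config-base.yaml are plain booleans: truefoundry_ml_log_merged_model (default True) controls whether the final merged model is uploaded at the end of training, and truefoundry_ml_log_gpu_metrics (default False) controls whether system/gpu* metrics are reported. A minimal sketch of reading them from the YAML, assuming PyYAML is available; load_plugin_flags is an illustrative helper, not code from this commit:

```python
import yaml  # assumption: PyYAML is installed in the training environment


def load_plugin_flags(path: str = "config-base.yaml") -> dict:
    """Read just the two flags touched by this commit, with the defaults from the diff."""
    with open(path) as f:
        cfg = yaml.safe_load(f) or {}
    return {
        "log_merged_model": cfg.get("truefoundry_ml_log_merged_model", True),
        "log_gpu_metrics": cfg.get("truefoundry_ml_log_gpu_metrics", False),
    }


# With the defaults above this returns {"log_merged_model": True, "log_gpu_metrics": False}.
```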
7 changes: 6 additions & 1 deletion plugins/axolotl_truefoundry/axolotl_truefoundry/__init__.py
@@ -100,10 +100,12 @@ def __init__(
run: "MlFoundryRun",
log_checkpoints: bool = True,
checkpoint_artifact_name: Optional[str] = None,
log_gpu_metrics: bool = False,
):
self._run = run
self._checkpoint_artifact_name = checkpoint_artifact_name
self._log_checkpoints = log_checkpoints
self._log_gpu_metrics = log_gpu_metrics

if not self._checkpoint_artifact_name:
logger.warning("checkpoint_artifact_name not passed. Checkpoints will not be logged to MLFoundry")
@@ -115,7 +117,7 @@ def on_log(self, args, state, control, logs, model=None, **kwargs):

metrics = {}
for k, v in logs.items():
if k.startswith("system/gpu"):
if k.startswith("system/gpu") and not self._log_gpu_metrics:
continue
if isinstance(v, (int, float, np.integer, np.floating)) and math.isfinite(v):
metrics[k] = v
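
With the new flag, on_log forwards system/gpu* entries only when log_gpu_metrics is enabled, and (as before) drops non-numeric or non-finite values. A standalone sketch of that filter, with the flag passed as an argument instead of being stored on the callback:

```python
import math

import numpy as np


def filter_log_metrics(logs: dict, log_gpu_metrics: bool = False) -> dict:
    """Keep finite numeric metrics; include system/gpu* only when explicitly enabled."""
    metrics = {}
    for k, v in logs.items():
        if k.startswith("system/gpu") and not log_gpu_metrics:
            continue  # GPU metrics are skipped unless the flag is set
        if isinstance(v, (int, float, np.integer, np.floating)) and math.isfinite(v):
            metrics[k] = v
    return metrics


# Example: the GPU entry and the NaN value are dropped with the default flag.
logs = {"loss": 0.42, "system/gpu.0.memory": 12000, "lr": float("nan")}
assert filter_log_metrics(logs) == {"loss": 0.42}
```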
@@ -184,6 +186,8 @@ class TruefoundryMLPluginArgs(BaseModel):
truefoundry_ml_run_name: Optional[str] = None
truefoundry_ml_log_checkpoints: bool = True
truefoundry_ml_checkpoint_artifact_name: Optional[str] = None
truefoundry_ml_log_merged_model: bool = True
truefoundry_ml_log_gpu_metrics: bool = False

cleanup_output_dir_on_start: bool = False
logging_dir: str = "./tensorboard_logs"
@@ -209,6 +213,7 @@ def add_callbacks_post_trainer(self, cfg: TruefoundryMLPluginArgs, trainer: Trai
run=run,
log_checkpoints=cfg.truefoundry_ml_log_checkpoints,
checkpoint_artifact_name=cfg.truefoundry_ml_checkpoint_artifact_name,
log_gpu_metrics=cfg.truefoundry_ml_log_gpu_metrics,
)
extra_metrics_cb = ExtraMetricsCallback()
tensorboard_cb_idx = None
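On the plugin side, the new booleans are declared on TruefoundryMLPluginArgs and handed to the callback in add_callbacks_post_trainer. A self-contained sketch of the args model, reproducing only the fields visible in this diff (the real class has more fields, which are omitted here):

```python
from typing import Optional

from pydantic import BaseModel


class TruefoundryMLPluginArgs(BaseModel):
    truefoundry_ml_run_name: Optional[str] = None
    truefoundry_ml_log_checkpoints: bool = True
    truefoundry_ml_checkpoint_artifact_name: Optional[str] = None
    truefoundry_ml_log_merged_model: bool = True   # new: skip the final model upload when False
    truefoundry_ml_log_gpu_metrics: bool = False   # new: report system/gpu* metrics when True


# Overriding just the new flags; everything else keeps the defaults above.
cfg = TruefoundryMLPluginArgs(truefoundry_ml_log_merged_model=False, truefoundry_ml_log_gpu_metrics=True)
print(cfg.truefoundry_ml_log_merged_model, cfg.truefoundry_ml_log_gpu_metrics)
```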
2 changes: 1 addition & 1 deletion train.py
@@ -270,7 +270,7 @@ def _train_with_truefoundry(config_base: Path = Path("examples/"), **kwargs):
if os.path.exists(readme_path):
shutil.copy2(readme_path, os.path.join(model_dir, "README.md"))
logger.info(f"Merged model has been saved to {model_dir}")
if cfg.truefoundry_ml_enable_reporting is True:
if cfg.truefoundry_ml_enable_reporting is True and cfg.truefoundry_ml_log_merged_model is True:
*_, model_name = cfg.base_model.rsplit("/", 1)
model_name = "-".join(["finetuned", model_name, timestamp])
model_name = sanitize_name(model_name)
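In train.py, the merged model is now uploaded only when both truefoundry_ml_enable_reporting and the new truefoundry_ml_log_merged_model are True; the artifact name is still derived from base_model plus a timestamp. A small sketch of that gate, with sanitize_name stubbed (its real implementation is not part of this diff) and an example base_model value chosen purely for illustration:

```python
from types import SimpleNamespace
from typing import Optional


def sanitize_name(name: str) -> str:
    # Placeholder: the repo's sanitize_name is not shown in this commit.
    return name.lower().replace("_", "-")


def merged_model_artifact_name(cfg, timestamp: str) -> Optional[str]:
    # Both switches must be on for the merged model to be logged.
    if not (cfg.truefoundry_ml_enable_reporting and cfg.truefoundry_ml_log_merged_model):
        return None
    *_, model_name = cfg.base_model.rsplit("/", 1)
    return sanitize_name("-".join(["finetuned", model_name, timestamp]))


cfg = SimpleNamespace(
    truefoundry_ml_enable_reporting=True,
    truefoundry_ml_log_merged_model=False,  # the new switch: skip uploading the merged model
    base_model="meta-llama/Llama-3.1-8B",
)
print(merged_model_artifact_name(cfg, "2024-11-19"))  # None -> the upload is skipped
```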
