Skip to content

Commit

Permalink
Fix metadata logging for artifacts and models
Browse files Browse the repository at this point in the history
  • Loading branch information
chiragjn committed Jul 3, 2024
1 parent abd87be commit af5e09f
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions mlfoundry_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@
TFY_INTERNAL_JOB_RUN_NAME = os.getenv("TFY_INTERNAL_JOB_RUN_NAME")


def _drop_non_finite_values(dct: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``dct`` without entries whose value is a non-finite float.

    Float values (Python ``float`` or ``np.floating``) that are NaN or
    +/-infinity are dropped, and a warning is logged for each dropped key.
    Every other value (integers, strings, ``None``, ...) is kept unchanged.

    Args:
        dct: Mapping of keys to arbitrary values.

    Returns:
        A new dict containing the retained entries in their original order.
        The input mapping is not modified.
    """
    sanitized = {}
    for k, v in dct.items():
        # Only floating point values can be NaN/inf. Integers are always finite,
        # and handing a very large Python int to math.isfinite raises
        # OverflowError during the implicit float conversion, so ints are
        # deliberately not passed through the finiteness check.
        if isinstance(v, (float, np.floating)) and not math.isfinite(v):
            logger.warning(f"Dropping non-finite value for key={k} value={v!r}")
            continue
        sanitized[k] = v
    return sanitized


def is_mlfoundry_artifact(value: str):
# TODO (chiragjn): This should be made more strict
if value.startswith(MLFOUNDRY_ARTIFACT_PREFIX):
Expand Down Expand Up @@ -77,9 +88,7 @@ def log_model_to_mlfoundry(
"huggingface_model_url": f"https://huggingface.co/{hf_hub_model_id}",
}
)
metadata = {
k: v for k, v in metadata.items() if isinstance(v, (int, float, np.integer, np.floating)) and math.isfinite(v)
}
metadata = _drop_non_finite_values(metadata)
run.log_model(
name=model_name,
model_file_or_folder=model_dir,
Expand Down Expand Up @@ -180,12 +189,7 @@ def on_save(self, args, state, control, **kwargs):
for log in state.log_history:
if isinstance(log, dict) and log.get("step") == state.global_step:
metadata = log.copy()

metadata = {
k: v
for k, v in metadata.items()
if isinstance(v, (int, float, np.integer, np.floating)) and math.isfinite(v)
}
metadata = _drop_non_finite_values(metadata)
self._run.log_artifact(
name=self._checkpoint_artifact_name,
artifact_paths=[(artifact_path,)],
Expand Down

0 comments on commit af5e09f

Please sign in to comment.