Skip to content
This repository has been archived by the owner on Sep 24, 2024. It is now read-only.

Commit

Permalink
respond to reviewer comments, implement url parse for artifact files
Browse files Browse the repository at this point in the history
  • Loading branch information
Sean Friedowitz committed Jan 19, 2024
1 parent 3708f5c commit 3cd9c48
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
17 changes: 14 additions & 3 deletions src/flamingo/integrations/wandb/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import contextlib
from pathlib import Path
from typing import Any
from urllib.parse import ParseResult, urlparse

import wandb
from wandb.apis.public import Run as ApiRun
Expand All @@ -16,7 +17,16 @@ def wandb_init_from_config(
parameters: BaseFlamingoConfig | None = None,
resume: str | None = None,
):
"""Initialize a W&B run from the internal run configuration."""
"""Initialize a W&B run from the internal run configuration.
This function can be used as a context manager, similar to `wandb.init`, as follows:
```
with wandb_init_from_config(run_config, resume="must") as run:
# Use the initialized run here
...
```
"""
init_kwargs = dict(
id=config.run_id,
name=config.name,
Expand Down Expand Up @@ -78,8 +88,9 @@ def resolve_artifact_path(path: str | WandbArtifactConfig) -> str:
# TODO: We should use artifact.download() here to get the data directory
# But we need to point the download root at a volume mount, which isn't wired up yet
for entry in artifact.manifest.entries.values():
if entry.ref.startswith("file://"):
return str(Path(entry.ref.replace("file://", "")).parent)
match urlparse(entry.ref):
case ParseResult(scheme="file", path=file_path):
return str(Path(file_path).parent)
raise ValueError(f"Artifact {artifact.name} does not contain a filesystem reference.")
case _:
raise ValueError(f"Invalid artifact path: {path}")
Expand Down
2 changes: 1 addition & 1 deletion src/flamingo/jobs/finetuning/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class FinetuningJobConfig(BaseFlamingoConfig):
@root_validator(pre=True)
def ensure_tokenizer_config(cls, values):
"""Set the tokenizer to the model path when not explicitly provided."""
if values.get("tokenizer", None) is None:
if values.get("tokenizer") is None:
values["tokenizer"] = {}
match values["model"]:
case str() as model_path:
Expand Down

0 comments on commit 3cd9c48

Please sign in to comment.