Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Storage cleaner] Add wandb path implementation #400

Merged
merged 13 commits into from
Dec 15, 2023
113 changes: 109 additions & 4 deletions scripts/storage_cleaner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
import botocore.exceptions as boto_exceptions
import google.cloud.storage as gcs
import torch
import wandb
from cached_path import add_scheme_client, cached_path, set_cache_dir
from cached_path.schemes import S3Client
from google.api_core.exceptions import NotFound
from omegaconf import OmegaConf as om
from rich.progress import Progress, TaskID, track

from olmo import util
Expand Down Expand Up @@ -946,8 +948,108 @@ def unshard_run_checkpoints(run_path: str, checkpoints_dest_dir: str, config: Un
_unshard_checkpoints(storage, run_dir_or_archive, checkpoints_dest_dir, config)


def _get_wandb_runs_from_wandb_dir(storage: StorageAdapter, wandb_dir: str, run_config: TrainConfig) -> List:
# For some reason, we often have a redundant nested wandb directory. Step into it here.
nested_wandb_dir = os.path.join(wandb_dir, "wandb/")
if storage.is_dir(nested_wandb_dir):
wandb_dir = nested_wandb_dir

# Wandb run directory names are stored in format <run>-<timestamp>-<id>
# https://docs.wandb.ai/guides/track/save-restore#examples-of-wandbsave
dir_names = storage.list_dirs(wandb_dir)
wandb_run_dir_names = [dir_name for dir_name in dir_names if dir_name.startswith("run")]
if len(wandb_run_dir_names) == 0:
log.warning("No wandb run directories found in wandb dir %s", wandb_dir)
return []

wandb_ids = [dir_name.split("-")[2] for dir_name in wandb_run_dir_names if dir_name.count("-") >= 2]

log.debug("Wandb ids: %s", wandb_ids)

assert run_config.wandb is not None
api: wandb.Api = wandb.Api()
return [api.run(path=f"{run_config.wandb.entity}/{run_config.wandb.project}/{id}") for id in wandb_ids]


def _get_wandb_path_from_run(wandb_run) -> str:
return "/".join(wandb_run.path)


def _get_wandb_runs_from_train_config(config: TrainConfig) -> List:
assert config.wandb is not None

run_filters = {
"display_name": config.wandb.name,
}
if config.wandb.group is not None:
run_filters["group"] = config.wandb.group

log.debug("Wandb entity/project: %s/%s", config.wandb.entity, config.wandb.project)
log.debug("Wandb filters: %s", run_filters)

api = wandb.Api()
return api.runs(path=f"{config.wandb.entity}/{config.wandb.project}", filters=run_filters)


def _are_equal_configs(wandb_config: TrainConfig, train_config: TrainConfig) -> bool:
return wandb_config.asdict(exclude=["wandb"]) == train_config.asdict(exclude=["wandb"])


def _get_wandb_config(wandb_run) -> TrainConfig:
local_storage = LocalFileSystemAdapter()
temp_file = local_storage.create_temp_file(suffix=".yaml")

om.save(config=wandb_run.config, f=temp_file)
wandb_config = TrainConfig.load(temp_file)

return wandb_config


def _get_matching_wandb_runs(wandb_runs, training_run_dir: str) -> List:
config_path = os.path.join(training_run_dir, CONFIG_YAML)
local_config_path = cached_path(config_path)
train_config = TrainConfig.load(local_config_path)

return [
wandb_run for wandb_run in wandb_runs if _are_equal_configs(_get_wandb_config(wandb_run), train_config)
]


def _get_wandb_path(run_dir: str) -> str:
raise NotImplementedError()
run_dir_storage = _get_storage_adapter_for_path(run_dir)

config_path = os.path.join(run_dir, CONFIG_YAML)
if not run_dir_storage.is_file(config_path):
raise FileNotFoundError("No config file found in run dir, cannot get wandb path")

local_config_path = cached_path(config_path)
config = TrainConfig.load(local_config_path, validate_paths=False)

if config.wandb is None or config.wandb.entity is None or config.wandb.project is None:
raise ValueError(f"Run at {run_dir} has missing wandb config, cannot get wandb run path")

wandb_runs = []

wandb_dir = os.path.join(run_dir, "wandb/")
if run_dir_storage.is_dir(wandb_dir):
wandb_runs += _get_wandb_runs_from_wandb_dir(run_dir_storage, wandb_dir, config)

wandb_runs += _get_wandb_runs_from_train_config(config)

# Remove duplicate wandb runs based on run path, and wandb runs that do not match our run.
wandb_runs = list({_get_wandb_path_from_run(wandb_run): wandb_run for wandb_run in wandb_runs}.values())
wandb_matching_runs = _get_matching_wandb_runs(wandb_runs, run_dir)

if len(wandb_matching_runs) == 0:
raise RuntimeError(f"Failed to find any wandb runs for {run_dir}. Run might no longer exist")

if len(wandb_matching_runs) > 1:
wandb_run_urls = [wandb_run.url for wandb_run in wandb_matching_runs]
raise RuntimeError(
f"Found {len(wandb_matching_runs)} runs matching run dir {run_dir}, cannot determine correct run: {wandb_run_urls}"
)

return _get_wandb_path_from_run(wandb_matching_runs[0])


def _append_wandb_path(
Expand All @@ -961,9 +1063,11 @@ def _append_wandb_path(

if _is_archive(run_dir_or_archive, run_dir_or_archive_storage) and append_archive_extension:
archive_extension = "".join(Path(run_dir_or_archive).suffixes)
wandb_path = wandb_path + archive_extension
relative_wandb_path = wandb_path + archive_extension
else:
relative_wandb_path = wandb_path + "/"

return os.path.join(base_dir, wandb_path)
return os.path.join(base_dir, relative_wandb_path)


def _copy(src_path: str, dest_path: str, temp_dir: str):
Expand Down Expand Up @@ -1045,7 +1149,7 @@ def _move_run(src_storage: StorageAdapter, run_dir_or_archive: str, dest_dir: st

src_move_path, dest_move_path = _get_src_and_dest_for_copy(src_storage, run_dir_or_archive, dest_dir, config)

if src_move_path == dest_move_path:
if src_move_path.rstrip("/") == dest_move_path.rstrip("/"):
# This could be a valid scenario if the user is, for example, trying to
# append wandb path to runs and this run has the right wandb path already.
log.info("Source and destination move paths are both %s, skipping", src_move_path)
Expand All @@ -1068,6 +1172,7 @@ def _move_run(src_storage: StorageAdapter, run_dir_or_archive: str, dest_dir: st
def move_run(run_path: str, dest_dir: str, config: MoveRunConfig):
storage = _get_storage_adapter_for_path(run_path)
run_dir_or_archive = _format_dir_or_archive_path(storage, run_path)
dest_dir = f"{dest_dir}/" if not dest_dir.endswith("/") else dest_dir
_move_run(storage, run_dir_or_archive, dest_dir, config)


Expand Down
Loading