diff --git a/scripts/storage_cleaner.py b/scripts/storage_cleaner.py index ca2c31f1e..552e74dc3 100644 --- a/scripts/storage_cleaner.py +++ b/scripts/storage_cleaner.py @@ -17,9 +17,11 @@ import botocore.exceptions as boto_exceptions import google.cloud.storage as gcs import torch +import wandb from cached_path import add_scheme_client, cached_path, set_cache_dir from cached_path.schemes import S3Client from google.api_core.exceptions import NotFound +from omegaconf import OmegaConf as om from rich.progress import Progress, TaskID, track from olmo import util @@ -946,8 +948,108 @@ def unshard_run_checkpoints(run_path: str, checkpoints_dest_dir: str, config: Un _unshard_checkpoints(storage, run_dir_or_archive, checkpoints_dest_dir, config) +def _get_wandb_runs_from_wandb_dir(storage: StorageAdapter, wandb_dir: str, run_config: TrainConfig) -> List: + # For some reason, we often have a redundant nested wandb directory. Step into it here. + nested_wandb_dir = os.path.join(wandb_dir, "wandb/") + if storage.is_dir(nested_wandb_dir): + wandb_dir = nested_wandb_dir + + # Wandb run directory names are stored in format -- + # https://docs.wandb.ai/guides/track/save-restore#examples-of-wandbsave + dir_names = storage.list_dirs(wandb_dir) + wandb_run_dir_names = [dir_name for dir_name in dir_names if dir_name.startswith("run")] + if len(wandb_run_dir_names) == 0: + log.warning("No wandb run directories found in wandb dir %s", wandb_dir) + return [] + + wandb_ids = [dir_name.split("-")[2] for dir_name in wandb_run_dir_names if dir_name.count("-") >= 2] + + log.debug("Wandb ids: %s", wandb_ids) + + assert run_config.wandb is not None + api: wandb.Api = wandb.Api() + return [api.run(path=f"{run_config.wandb.entity}/{run_config.wandb.project}/{id}") for id in wandb_ids] + + +def _get_wandb_path_from_run(wandb_run) -> str: + return "/".join(wandb_run.path) + + +def _get_wandb_runs_from_train_config(config: TrainConfig) -> List: + assert config.wandb is not None + + run_filters = { + "display_name": config.wandb.name, + } + if config.wandb.group is not None: + run_filters["group"] = config.wandb.group + + log.debug("Wandb entity/project: %s/%s", config.wandb.entity, config.wandb.project) + log.debug("Wandb filters: %s", run_filters) + + api = wandb.Api() + return api.runs(path=f"{config.wandb.entity}/{config.wandb.project}", filters=run_filters) + + +def _are_equal_configs(wandb_config: TrainConfig, train_config: TrainConfig) -> bool: + return wandb_config.asdict(exclude=["wandb"]) == train_config.asdict(exclude=["wandb"]) + + +def _get_wandb_config(wandb_run) -> TrainConfig: + local_storage = LocalFileSystemAdapter() + temp_file = local_storage.create_temp_file(suffix=".yaml") + + om.save(config=wandb_run.config, f=temp_file) + wandb_config = TrainConfig.load(temp_file) + + return wandb_config + + +def _get_matching_wandb_runs(wandb_runs, training_run_dir: str) -> List: + config_path = os.path.join(training_run_dir, CONFIG_YAML) + local_config_path = cached_path(config_path) + train_config = TrainConfig.load(local_config_path) + + return [ + wandb_run for wandb_run in wandb_runs if _are_equal_configs(_get_wandb_config(wandb_run), train_config) + ] + + def _get_wandb_path(run_dir: str) -> str: - raise NotImplementedError() + run_dir_storage = _get_storage_adapter_for_path(run_dir) + + config_path = os.path.join(run_dir, CONFIG_YAML) + if not run_dir_storage.is_file(config_path): + raise FileNotFoundError("No config file found in run dir, cannot get wandb path") + + local_config_path = cached_path(config_path) + config = TrainConfig.load(local_config_path, validate_paths=False) + + if config.wandb is None or config.wandb.entity is None or config.wandb.project is None: + raise ValueError(f"Run at {run_dir} has missing wandb config, cannot get wandb run path") + + wandb_runs = [] + + wandb_dir = os.path.join(run_dir, "wandb/") + if run_dir_storage.is_dir(wandb_dir): + wandb_runs += _get_wandb_runs_from_wandb_dir(run_dir_storage, wandb_dir, config) + + wandb_runs += _get_wandb_runs_from_train_config(config) + + # Remove duplicate wandb runs based on run path, and wandb runs that do not match our run. + wandb_runs = list({_get_wandb_path_from_run(wandb_run): wandb_run for wandb_run in wandb_runs}.values()) + wandb_matching_runs = _get_matching_wandb_runs(wandb_runs, run_dir) + + if len(wandb_matching_runs) == 0: + raise RuntimeError(f"Failed to find any wandb runs for {run_dir}. Run might no longer exist") + + if len(wandb_matching_runs) > 1: + wandb_run_urls = [wandb_run.url for wandb_run in wandb_matching_runs] + raise RuntimeError( + f"Found {len(wandb_matching_runs)} runs matching run dir {run_dir}, cannot determine correct run: {wandb_run_urls}" + ) + + return _get_wandb_path_from_run(wandb_matching_runs[0]) def _append_wandb_path( @@ -961,9 +1063,11 @@ def _append_wandb_path( if _is_archive(run_dir_or_archive, run_dir_or_archive_storage) and append_archive_extension: archive_extension = "".join(Path(run_dir_or_archive).suffixes) - wandb_path = wandb_path + archive_extension + relative_wandb_path = wandb_path + archive_extension + else: + relative_wandb_path = wandb_path + "/" - return os.path.join(base_dir, wandb_path) + return os.path.join(base_dir, relative_wandb_path) def _copy(src_path: str, dest_path: str, temp_dir: str): @@ -1045,7 +1149,7 @@ def _move_run(src_storage: StorageAdapter, run_dir_or_archive: str, dest_dir: st src_move_path, dest_move_path = _get_src_and_dest_for_copy(src_storage, run_dir_or_archive, dest_dir, config) - if src_move_path == dest_move_path: + if src_move_path.rstrip("/") == dest_move_path.rstrip("/"): # This could be a valid scenario if the user is, for example, trying to # append wandb path to runs and this run has the right wandb path already. log.info("Source and destination move paths are both %s, skipping", src_move_path) @@ -1068,6 +1172,7 @@ def _move_run(src_storage: StorageAdapter, run_dir_or_archive: str, dest_dir: st def move_run(run_path: str, dest_dir: str, config: MoveRunConfig): storage = _get_storage_adapter_for_path(run_path) run_dir_or_archive = _format_dir_or_archive_path(storage, run_path) + dest_dir = f"{dest_dir}/" if not dest_dir.endswith("/") else dest_dir _move_run(storage, run_dir_or_archive, dest_dir, config)