diff --git a/scripts/storage_cleaner.py b/scripts/storage_cleaner.py index 05e7fbb65..2d9e2abc3 100644 --- a/scripts/storage_cleaner.py +++ b/scripts/storage_cleaner.py @@ -632,6 +632,7 @@ class MoveRunConfig(StorageCleanerConfig): append_wandb_path: bool keep_src: bool store_archived: bool + entry: Optional[str] def _get_storage_adapter_for_path(path: str) -> StorageAdapter: @@ -1166,8 +1167,11 @@ def _get_src_dest_pairs_for_copy( src_storage: StorageAdapter, run_dir_or_archive: str, dest_dir: str, config: MoveRunConfig ) -> List[Tuple[str, str]]: is_archive_file = _is_archive(run_dir_or_archive, src_storage) + + if is_archive_file and config.entry is not None: + raise NotImplementedError("Cannot move only an entry if run is an archive file") if is_archive_file and config.append_wandb_path: - raise ValueError("Cannot append wandb path for run archive files") + raise NotImplementedError("Cannot append wandb path for run archive files") if is_archive_file: if config.store_archived: @@ -1184,15 +1188,29 @@ def _get_src_dest_pairs_for_copy( run_dir_storage = _get_storage_adapter_for_path(run_dir) if not run_dir_storage.is_dir(run_dir): raise ValueError(f"Run directory {run_dir} does not exist") + if config.entry is not None: + entry_src_path = os.path.join(run_dir, config.entry) + if not run_dir_storage.is_dir(entry_src_path) and not run_dir_storage.is_file(entry_src_path): + raise ValueError(f"Entry {config.entry} does not exist in directory {run_dir}") if not config.append_wandb_path: - return [(run_dir, dest_dir)] + if config.entry is None: + return [(run_dir, dest_dir)] + + entry_src_path = os.path.join(run_dir, config.entry) + entry_dest_path = os.path.join(dest_dir, config.entry) + return [(entry_src_path, entry_dest_path)] assert config.append_wandb_path and not is_archive_file - checkpoint_dirs = _get_checkpoint_dirs(run_dir, run_dir_storage) + checkpoint_to_wandb_path: Dict[str, str] # TODO: Update _get_wandb_path to get the wandb path for a checkpoint rather than a run directory. # A run directory could correspond to multiple wandb runs. - checkpoint_to_wandb_path = {checkpoint_dir: _get_wandb_path(run_dir) for checkpoint_dir in checkpoint_dirs} + if config.entry is not None and _is_checkpoint_dir(entry_path := os.path.join(run_dir, config.entry)): + # No need to consider other checkpoints if we are filtering for a specific checkpoint + checkpoint_to_wandb_path = {entry_path: _get_wandb_path(run_dir)} + else: + checkpoint_dirs = _get_checkpoint_dirs(run_dir, run_dir_storage) + checkpoint_to_wandb_path = {checkpoint_dir: _get_wandb_path(run_dir) for checkpoint_dir in checkpoint_dirs} src_dest_pairs: List[Tuple[str, str]] = [] # Mappings of source checkpoint directories to destination checkpoint directories @@ -1211,17 +1229,27 @@ def _get_src_dest_pairs_for_copy( for wandb_path in set(checkpoint_to_wandb_path.values()) ] + if config.entry is not None: + src_dest_pairs = [ + src_dest_pair for src_dest_pair in src_dest_pairs if Path(src_dest_pair[0]).match(config.entry) + ] + return src_dest_pairs def _move_run(src_storage: StorageAdapter, run_dir_or_archive: str, dest_dir: str, config: MoveRunConfig): log.info("Moving run directory or archive %s to directory %s", run_dir_or_archive, dest_dir) + if not config.keep_src and config.entry is not None: + raise ValueError("Cannot delete source when an entry of the run is specified.") + dest_storage = _get_storage_adapter_for_path(dest_dir) if dest_storage.is_file(dest_dir): raise ValueError(f"Destination directory {dest_dir} is a file") src_dest_path_pairs = _get_src_dest_pairs_for_copy(src_storage, run_dir_or_archive, dest_dir, config) + if len(src_dest_path_pairs) == 0: + raise RuntimeError("Found no valid source-destination pairs to move.") for src_move_path, dest_move_path in src_dest_path_pairs: if src_move_path.rstrip("/") == dest_move_path.rstrip("/"): @@ -1335,6 +1363,7 @@ def perform_operation(args: argparse.Namespace): append_wandb_path=args.append_wandb_path, keep_src=args.keep_src, store_archived=args.store_archived, + entry=args.entry, ) if args.run_path is not None and args.dest_dir is not None: move_run(args.run_path, args.dest_dir, move_run_config) @@ -1342,6 +1371,8 @@ def perform_operation(args: argparse.Namespace): raise ValueError("Run path or dest dir not provided for moving run") else: raise NotImplementedError(args.op) + + log.info("Operation completed successfully!") finally: if Path(temp_dir).is_dir(): log.info("Deleting temp dir %s", temp_dir) @@ -1440,6 +1471,11 @@ def _add_move_subparser(subparsers: _SubParsersAction): action="store_true", help="If set, the wandb path for the run is found and appended to the destination dir. If the run is being stored as an archive file, wandb id is first removed from the wandb path and used as the filename.", ) + move_parser.add_argument( + "--entry", + default=None, + help="If provided, only the directory/file with this name within the run is moved. Example: 'step0-unsharded'.", + ) def get_parser() -> ArgumentParser: