From c1b4adda9eef315c95bd8a6314e72d6ab40d0c27 Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Fri, 20 Oct 2023 17:28:45 -0700 Subject: [PATCH 1/2] Skip some keys while unsharding --- scripts/unshard.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/scripts/unshard.py b/scripts/unshard.py index 6a4f8301d..3ce890848 100644 --- a/scripts/unshard.py +++ b/scripts/unshard.py @@ -1,10 +1,9 @@ import logging import shutil -import sys from concurrent.futures import ThreadPoolExecutor from functools import reduce from pathlib import Path -from typing import Any, Dict, List, Tuple, Union, cast +from typing import Any, Dict, List, Tuple, Union, cast, Set, Optional import numpy as np import torch @@ -90,11 +89,15 @@ def objects_are_equal(a: Any, b: Any) -> bool: return a == b -def unshard(input_dir: Union[str, Path], output_dir: Union[str, Path]) -> None: +def unshard( + input_dir: Union[str, Path], output_dir: Union[str, Path], skip_keys: Optional[Set[str]] = None +) -> None: if isinstance(input_dir, str): input_dir = Path(input_dir) if isinstance(output_dir, str): output_dir = Path(output_dir) + if skip_keys is None: + skip_keys = set() output_dir.mkdir(parents=True, exist_ok=True) # Monkeypatch torch's ShardedTensor, so we can unpickle without having torch.distributed set up. @@ -130,7 +133,11 @@ def _rebuild_from_type_v2_monkey(func, new_type, args, state): shards_dict[shard_number] = executor.submit(torch.load, shard_name, map_location="cpu") shards = [None] * len(shards_dict) for rank, shard_future in shards_dict.items(): - shards[rank] = shard_future.result() + shard = shard_future.result() + for key in skip_keys: + if key in shard: + del shard[key] + shards[rank] = shard assert all(shard is not None for shard in shards) executor.shutdown() del shards_dict @@ -184,9 +191,13 @@ def unshard_object(os: List[Any]) -> Any: if __name__ == "__main__": - if len(sys.argv) != 3: - sys.stderr.write("Usage: unshard.py ") - sys.exit(1) - else: - logging.basicConfig(level=logging.INFO) - unshard(sys.argv[1], sys.argv[2]) + import argparse + + parser = argparse.ArgumentParser(prog="unshard.py", description="Unshard sharded checkpoints on CPU") + parser.add_argument("input_dir") + parser.add_argument("output_dir") + parser.add_argument("--skip_key", nargs="*", default=[], action="extend") + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO) + unshard(args.input_dir, args.output_dir, set(args.skip_key)) From 595720f2c342baa89425f2c54c560e3115745b96 Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Wed, 25 Oct 2023 14:37:44 -0700 Subject: [PATCH 2/2] isort --- scripts/unshard.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/unshard.py b/scripts/unshard.py index 3ce890848..b97435a35 100644 --- a/scripts/unshard.py +++ b/scripts/unshard.py @@ -3,7 +3,7 @@ from concurrent.futures import ThreadPoolExecutor from functools import reduce from pathlib import Path -from typing import Any, Dict, List, Tuple, Union, cast, Set, Optional +from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast import numpy as np import torch