Skip to content

Commit

Permalink
Merge pull request #337 from allenai/UnshardSkipKeys
Browse files Browse the repository at this point in the history
Skip keys during unsharding
  • Loading branch information
dirkgr authored Oct 25, 2023
2 parents 4980bad + 595720f commit a465caa
Showing 1 changed file with 21 additions and 10 deletions.
31 changes: 21 additions & 10 deletions scripts/unshard.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import logging
import shutil
import sys
from concurrent.futures import ThreadPoolExecutor
from functools import reduce
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast

import numpy as np
import torch
Expand Down Expand Up @@ -90,11 +89,15 @@ def objects_are_equal(a: Any, b: Any) -> bool:
return a == b


def unshard(input_dir: Union[str, Path], output_dir: Union[str, Path]) -> None:
def unshard(
input_dir: Union[str, Path], output_dir: Union[str, Path], skip_keys: Optional[Set[str]] = None
) -> None:
if isinstance(input_dir, str):
input_dir = Path(input_dir)
if isinstance(output_dir, str):
output_dir = Path(output_dir)
if skip_keys is None:
skip_keys = set()
output_dir.mkdir(parents=True, exist_ok=True)

# Monkeypatch torch's ShardedTensor, so we can unpickle without having torch.distributed set up.
Expand Down Expand Up @@ -130,7 +133,11 @@ def _rebuild_from_type_v2_monkey(func, new_type, args, state):
shards_dict[shard_number] = executor.submit(torch.load, shard_name, map_location="cpu")
shards = [None] * len(shards_dict)
for rank, shard_future in shards_dict.items():
shards[rank] = shard_future.result()
shard = shard_future.result()
for key in skip_keys:
if key in shard:
del shard[key]
shards[rank] = shard
assert all(shard is not None for shard in shards)
executor.shutdown()
del shards_dict
Expand Down Expand Up @@ -184,9 +191,13 @@ def unshard_object(os: List[Any]) -> Any:


if __name__ == "__main__":
if len(sys.argv) != 3:
sys.stderr.write("Usage: unshard.py <input dir> <output dir>")
sys.exit(1)
else:
logging.basicConfig(level=logging.INFO)
unshard(sys.argv[1], sys.argv[2])
import argparse

parser = argparse.ArgumentParser(prog="unshard.py", description="Unshard sharded checkpoints on CPU")
parser.add_argument("input_dir")
parser.add_argument("output_dir")
parser.add_argument("--skip_key", nargs="*", default=[], action="extend")
args = parser.parse_args()

logging.basicConfig(level=logging.INFO)
unshard(args.input_dir, args.output_dir, set(args.skip_key))

0 comments on commit a465caa

Please sign in to comment.