Commit

try out fsdp sharding
thejaminator committed May 2, 2023
1 parent ffdef0e commit 54bee5d
Showing 2 changed files with 4 additions and 4 deletions.
llama_device_map.py: 3 additions & 3 deletions
@@ -117,9 +117,9 @@ def main(args):
     max_memory = (
         {0: forty_gb, 1: forty_gb}
         if not use_8bit
-        # this is a hack since infer_auto_device_map doesn't detect 8bit
-        # even if we load it in 8bit
-        # for big models, it'll start allocating to disk
+        # this is a hack since infer_auto_device_map can't detect
+        # that we're using 8bit, since we inited an empty model
+        # to analyse.
         else {0: forty_gb * 2, 1: forty_gb * 2}
     )
     autodevice_map = infer_auto_device_map(
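The reworded comment above is about accelerate's device-map planner: infer_auto_device_map is run against an empty (meta) model, so it cannot see that the real weights will later be loaded in 8-bit, and the per-GPU budget is doubled to compensate. A minimal sketch of that pattern follows; the checkpoint name, the no_split_module_classes value, and everything outside forty_gb, use_8bit, and the infer_auto_device_map call are illustrative assumptions, not taken from the repository.

# Sketch only: checkpoint name and no_split_module_classes are assumed for illustration.
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

forty_gb = 40 * 1024**3  # per-GPU memory budget in bytes
use_8bit = True          # hypothetical flag mirroring the diff

config = AutoConfig.from_pretrained("huggyllama/llama-7b")  # assumed checkpoint
with init_empty_weights():
    # Empty (meta) model: the planner sizes modules at their full dtype,
    # so it has no way to know the real weights will be quantized to 8-bit.
    empty_model = AutoModelForCausalLM.from_config(config)

max_memory = (
    {0: forty_gb, 1: forty_gb}
    if not use_8bit
    # Doubling the budget makes the planner place roughly twice as many
    # full-precision-sized parameters per GPU, which is about right once
    # they actually occupy half the space in 8-bit.
    else {0: forty_gb * 2, 1: forty_gb * 2}
)

device_map = infer_auto_device_map(
    empty_model,
    max_memory=max_memory,
    no_split_module_classes=["LlamaDecoderLayer"],  # assumed; keeps decoder blocks on one device
)
print(device_map)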
llama_fsdp.py: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def run_inference(
         sync_module_states=False,
         limit_all_gathers=False,
         forward_prefetch=True,
-        strategy=strategy,
+        sharding_strategy=strategy,
     )

     if rank == 0:
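The one-line fix renames the keyword to sharding_strategy, which is the argument name torch.distributed.fsdp.FullyShardedDataParallel accepts. A minimal sketch of the wrapping call, assuming a FULL_SHARD strategy and one GPU per rank; the helper name, process-group setup, and strategy choice are illustrative, and only the keyword arguments mirror the diff.

# Sketch only: assumes the process group was already initialised by the caller
# (e.g. torchrun + dist.init_process_group).
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy


def wrap_for_inference(model: torch.nn.Module, rank: int) -> FSDP:
    assert dist.is_initialized(), "init the process group before wrapping"
    strategy = ShardingStrategy.FULL_SHARD  # assumed choice; shards parameters across ranks
    return FSDP(
        model,
        device_id=rank,                  # assumed: one GPU per rank
        sync_module_states=False,
        limit_all_gathers=False,
        forward_prefetch=True,
        sharding_strategy=strategy,      # the keyword FSDP expects; plain strategy= is not accepted
    )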
