
Commit

fix usage of accelerate
Seventeen17 committed Jun 28, 2024
1 parent 641b452 commit e200a3b
Showing 16 changed files with 542 additions and 175 deletions.
2 changes: 1 addition & 1 deletion examples/llama_acc.sh
@@ -2,7 +2,7 @@
set -ex

# FSDP
- ./examples/run.sh --model ./hf_models/config/llama-1b --accelerator acc --gc --mbs 4 --fsdp 4
+ ./examples/run.sh --model ./hf_models/config/llama-1b --accelerator acc --gc --mbs 4 --fsdp 4 --use_flash_attn

# TP
# ./examples/run.sh --model ./hf_models/config/llama-1b --accelerator acc --gc --mbs 24 --tp 4
2 changes: 1 addition & 1 deletion flashmodels/accelerators/acc_baichuan_accelerator.py
@@ -17,7 +17,7 @@ def accelerate_internal(self, model, loader):
raise NotImplementedError("resume_from_checkpoint.")

config = self.get_config(model)
- model = ta.accelerate(model, config)
+ model = ta.accelerate(model, config=config)
return model, loader

def get_config(self, model):
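
Note: the same one-line change repeats in each accelerator file below — `config` is now passed to `ta.accelerate` by keyword instead of positionally. A minimal sketch of the resulting call pattern, assuming `ta` is the torchacc module these accelerator classes import; the helper name `build_config` is illustrative and stands in for the per-model `get_config` methods shown in this diff:

# Minimal sketch of the call pattern this commit standardizes.
# Assumptions: `ta` is torchacc, imported as in the accelerator modules;
# `build_config` is a hypothetical stand-in for the per-model get_config methods.
import torchacc as ta

def accelerate_model(model, loader, build_config):
    config = build_config(model)
    # Pass the config by keyword so it binds explicitly to ta.accelerate's
    # `config` parameter rather than to whatever parameter happens to come
    # second positionally.
    model = ta.accelerate(model, config=config)
    return model, loader
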
2 changes: 1 addition & 1 deletion flashmodels/accelerators/acc_gemma_accelerator.py
@@ -12,7 +12,7 @@ def accelerate(self, model, loader):

def accelerate_internal(self, model, loader):
config = self.get_config()
- model = ta.accelerate(model, config)
+ model = ta.accelerate(model, config=config)
return model, loader

def get_config(self):
2 changes: 1 addition & 1 deletion flashmodels/accelerators/acc_glm_accelerator.py
@@ -17,7 +17,7 @@ def accelerate_internal(self, model, loader):
raise NotImplementedError("resume_from_checkpoint.")

config = self.get_config(model)
- model = ta.accelerate(model, config)
+ model = ta.accelerate(model, config=config)
return model, loader

def get_config(self, model):
2 changes: 1 addition & 1 deletion flashmodels/accelerators/acc_gpt_accelerator.py
@@ -20,7 +20,7 @@ def accelerate_internal(self, model, loader):
raise NotImplementedError("resume_from_checkpoint.")

config = self.get_config(model)
- model = ta.accelerate(model, config)
+ model = ta.accelerate(model, config=config)
return model, loader

device = lazy_device()
2 changes: 1 addition & 1 deletion flashmodels/accelerators/acc_llama_accelerator.py
@@ -99,7 +99,7 @@ def accelerate_internal(self, model, loader):
self.args.sp)

config = self.get_config(model)
- model = ta.accelerate(model, config)
+ model = ta.accelerate(model, config=config)

if self.args.tp_num > 1 and self.args.pp_num > 1:
self.parallel_3d(model._get_underlay_model())
2 changes: 1 addition & 1 deletion flashmodels/accelerators/acc_olmo_accelerator.py
@@ -17,7 +17,7 @@ def accelerate_internal(self, model, loader):
raise NotImplementedError("resume_from_checkpoint.")

config = self.get_config(model)
- model = ta.accelerate(model, config)
+ model = ta.accelerate(model, config=config)
return model, loader
else:
raise NotImplementedError("Currently, only FSDP is supported.")
2 changes: 1 addition & 1 deletion flashmodels/accelerators/acc_qwen_accelerator.py
@@ -37,7 +37,7 @@ def accelerate_internal(self, model, loader):
raise NotImplementedError("resume_from_checkpoint.")

config = self.get_config(model)
- model = ta.accelerate(model, config)
+ model = ta.accelerate(model, config=config)
return model, loader

def get_config(self, model):
7 changes: 3 additions & 4 deletions flashmodels/accelerators/cuda_llama_accelerator.py
@@ -13,6 +13,7 @@
FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
+ from transformers.models.llama.modeling_llama import LlamaDecoderLayer

from flashmodels.accelerators.accelerator import (Accelerator,
AcceleratorFactory)
@@ -70,7 +71,7 @@ def apply_checkpointing(self, model):
checkpoint_wrapper,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
- check_fn = lambda submodule: isinstance(submodule, transformers.models.llama.modeling_llama.LlamaDecoderLayer)
+ check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer)
apply_activation_checkpointing(
model,
checkpoint_wrapper_fn=non_reentrant_wrapper,
@@ -96,9 +97,7 @@ def fsdp(self, model):
convert_outputs_to_fp32(model.forward.__func__), model)

# Use auto_wrap_policy for nested wrapping instead of only a top-level FSDP.
- auto_wrap_policy = ModuleWrapPolicy({
-     transformers.models.llama.modeling_llama.LlamaDecoderLayer,
- })
+ auto_wrap_policy = ModuleWrapPolicy({LlamaDecoderLayer, })

mixed_precision_policy = None
if self.args.fp16 or self.args.bf16:
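
Note: for the CUDA path, the changes above import LlamaDecoderLayer directly and use it both as the FSDP auto-wrap boundary and as the activation-checkpointing target. A condensed sketch of how these standard PyTorch and Transformers APIs fit together under those assumptions; it is not a verbatim reconstruction of cuda_llama_accelerator.py, and the helper name wrap_llama_with_fsdp is illustrative:

# Sketch: nested FSDP wrapping at the decoder-layer boundary plus non-reentrant
# activation checkpointing on the same layer class.
import functools

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl, apply_activation_checkpointing, checkpoint_wrapper)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer


def wrap_llama_with_fsdp(model):
    # Shard each decoder layer in its own FSDP unit instead of a single
    # top-level wrap, mirroring the ModuleWrapPolicy usage above.
    auto_wrap_policy = ModuleWrapPolicy({LlamaDecoderLayer})
    model = FSDP(model, auto_wrap_policy=auto_wrap_policy)

    # Recompute decoder-layer activations during backward (non-reentrant
    # checkpointing), selecting layers with the same isinstance check as the
    # corrected check_fn above.
    non_reentrant_wrapper = functools.partial(
        checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=non_reentrant_wrapper,
        check_fn=lambda m: isinstance(m, LlamaDecoderLayer))
    return model
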
