Commit
move import back
Seventeen17 committed Jun 28, 2024
1 parent e200a3b commit f35bda9
Showing 7 changed files with 9 additions and 394 deletions.
5 changes: 3 additions & 2 deletions apps/train.py
@@ -1,15 +1,16 @@
import torch

-from flashmodels import Builder, Trainer, accelerate, arguments


def train():
    torch.manual_seed(101)

    # parse args
+    from flashmodels import arguments
    args = arguments.parse()

    # build model, tokenizer, loader, optimizer and lr_scheduler
    # and use accelerator to speed up training
+    from flashmodels import Builder, Trainer, accelerate
    builder = Builder(args)
    model, loader, tokenizer = builder.build_model_dataloader()
    model, loader = accelerate(model, loader, args)
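
The hunk above is the heart of the commit: the flashmodels imports move back from module level into train(). The commit message does not say why; deferring flashmodels' import-time work (and any circular-import risk) until train() actually runs is a plausible motivation, but that is an assumption. The standard-library sketch below only illustrates why the pattern stays cheap: after the first import, Python resolves in-function imports from the sys.modules cache.

# Minimal, runnable sketch (standard library only); json stands in for a
# heavier package such as flashmodels.
import importlib
import sys
import time


def timed_import(name):
    # Only the first import pays the module-initialization cost; later
    # imports of the same name are served from the sys.modules cache.
    start = time.perf_counter()
    importlib.import_module(name)
    return time.perf_counter() - start


if __name__ == "__main__":
    sys.modules.pop("json", None)              # force a cold first import
    print(f"cold import:   {timed_import('json'):.6f}s")
    print(f"cached import: {timed_import('json'):.6f}s")
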
6 changes: 4 additions & 2 deletions flashmodels/accelerators/cuda_llama_accelerator.py
@@ -71,7 +71,7 @@ def apply_checkpointing(self, model):
            checkpoint_wrapper,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        )
-        check_fn = lambda submodule: isinstance(LlamaDecoderLayer)
+        check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer)
        apply_activation_checkpointing(
            model,
            checkpoint_wrapper_fn=non_reentrant_wrapper,
@@ -97,7 +97,9 @@ def fsdp(self, model):
            convert_outputs_to_fp32(model.forward.__func__), model)

        # Use auto_wrap_poliy for nested wrapping instead of only a top-level FSDP.
-        auto_wrap_policy = ModuleWrapPolicy({LlamaDecoderLayer, })
+        auto_wrap_policy = ModuleWrapPolicy({
+            LlamaDecoderLayer,
+        })

        mixed_precision_policy = None
        if self.args.fp16 or self.args.bf16:
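
Two things change in cuda_llama_accelerator.py. First, the check_fn lambda now passes the submodule to isinstance; the old one-argument call isinstance(LlamaDecoderLayer) would raise a TypeError the first time checkpointing was applied. Second, the ModuleWrapPolicy({LlamaDecoderLayer}) construction is only reflowed across lines, with no change in behavior: each decoder layer is still wrapped as its own FSDP unit. Below is a minimal sketch of the same non-reentrant activation-checkpointing pattern; Block is a stand-in for LlamaDecoderLayer so the snippet runs on CPU without transformers, and the FSDP/torchacc pieces are left out.

# Sketch only: Block stands in for LlamaDecoderLayer; the real accelerator
# applies this to an FSDP-wrapped Llama model.
import functools

import torch
from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl, apply_activation_checkpointing, checkpoint_wrapper)


class Block(nn.Module):
    def __init__(self, dim=32):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.GELU(),
                                nn.Linear(dim, dim))

    def forward(self, x):
        return x + self.ff(x)


model = nn.Sequential(*[Block() for _ in range(4)])

non_reentrant_wrapper = functools.partial(
    checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
# The fixed predicate: isinstance needs both the object and the class.
check_fn = lambda submodule: isinstance(submodule, Block)
apply_activation_checkpointing(
    model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn)

loss = model(torch.randn(2, 32)).sum()
loss.backward()  # activations inside each Block are recomputed here
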
273 changes: 0 additions & 273 deletions flashmodels/arguments.py.bak

This file was deleted.

5 changes: 0 additions & 5 deletions flashmodels/patch/__init__.py.bak

This file was deleted.

1 change: 0 additions & 1 deletion flashmodels/patch/llama_model.py
@@ -8,7 +8,6 @@
import torch_xla.core.xla_model as xm
from torch import nn
from torchacc.dist.tp import Mesh, mark_sharding
-from transformer.cache_utils import Cache
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (ACT2FN, LlamaRMSNorm,
                                                       LlamaRotaryEmbedding,
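
The only change in llama_model.py removes the line from transformer.cache_utils import Cache. The module path looks like a typo (the KV-cache base class ships as transformers.cache_utils, with an "s", in recent transformers releases), and the patched model apparently no longer needs it. For reference, the correctly spelled path can be checked with:

# Assumes a transformers release that ships cache_utils (roughly >= 4.36).
from transformers.cache_utils import Cache

print(Cache.__module__)  # -> transformers.cache_utils
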
2 changes: 2 additions & 0 deletions flashmodels/patch/patch.py
@@ -33,6 +33,8 @@ def rewrite_load():

def patch_llama(use_flash_attn):
    patch.patch_llama(use_flash_attn)
+    from flashmodels.patch.llama_model import (LlamaAttention,
+                                               LlamaDecoderLayer, LlamaMLP)
    if os.environ.get("ACC_LLAMA_TP") == "1":
        transformers.models.llama.modeling_llama.LlamaMLP = LlamaMLP
    if os.getenv("XLA_USE_SPMD") == "1":
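
patch_llama() now imports the replacement classes from flashmodels.patch.llama_model inside the function (the import that was moved back) and swaps them into transformers' namespace when the matching environment variable is set. The sketch below shows the same env-var-gated monkey patch in isolation; it assumes transformers is installed, and PatchedLlamaMLP is a hypothetical stand-in for flashmodels' tensor-parallel LlamaMLP.

import os

import transformers
from transformers.models.llama.modeling_llama import LlamaMLP


class PatchedLlamaMLP(LlamaMLP):
    """Hypothetical replacement, e.g. built from tensor-parallel linears."""


def patch_llama():
    # The real code defers the import of the replacement class to call time
    # ("move import back"); here it is defined above for brevity.  The swap
    # is a plain attribute assignment, so any Llama model constructed after
    # this point resolves LlamaMLP to the patched class.
    if os.environ.get("ACC_LLAMA_TP") == "1":
        transformers.models.llama.modeling_llama.LlamaMLP = PatchedLlamaMLP


os.environ["ACC_LLAMA_TP"] = "1"
patch_llama()
print(transformers.models.llama.modeling_llama.LlamaMLP)  # -> PatchedLlamaMLP
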

0 comments on commit f35bda9
