From f35bda90ad284275d974f1523a6ef5dfb4976461 Mon Sep 17 00:00:00 2001
From: Seventeen17 <17aloha@gmail.com>
Date: Fri, 28 Jun 2024 15:53:55 +0800
Subject: [PATCH] move import back

---
 apps/train.py                                |   5 +-
 .../accelerators/cuda_llama_accelerator.py   |   6 +-
 flashmodels/arguments.py.bak                 | 273 ------------------
 flashmodels/patch/__init__.py.bak            |   5 -
 flashmodels/patch/llama_model.py             |   1 -
 flashmodels/patch/patch.py                   |   2 +
 flashmodels/patch/patch.py.bak               | 111 -------
 7 files changed, 9 insertions(+), 394 deletions(-)
 delete mode 100644 flashmodels/arguments.py.bak
 delete mode 100644 flashmodels/patch/__init__.py.bak
 delete mode 100644 flashmodels/patch/patch.py.bak

diff --git a/apps/train.py b/apps/train.py
index b6d790a..7c0a48a 100644
--- a/apps/train.py
+++ b/apps/train.py
@@ -1,15 +1,16 @@
 import torch
 
+from flashmodels import Builder, Trainer, accelerate, arguments
+
+
 def train():
     torch.manual_seed(101)
 
     # parse args
-    from flashmodels import arguments
     args = arguments.parse()
 
     # build model, tokenizer, loader, optimizer and lr_scheduler
     # and use accelerator to speed up training
-    from flashmodels import Builder, Trainer, accelerate
     builder = Builder(args)
     model, loader, tokenizer = builder.build_model_dataloader()
     model, loader = accelerate(model, loader, args)
diff --git a/flashmodels/accelerators/cuda_llama_accelerator.py b/flashmodels/accelerators/cuda_llama_accelerator.py
index a27da9d..288f771 100644
--- a/flashmodels/accelerators/cuda_llama_accelerator.py
+++ b/flashmodels/accelerators/cuda_llama_accelerator.py
@@ -71,7 +71,7 @@ def apply_checkpointing(self, model):
             checkpoint_wrapper,
             checkpoint_impl=CheckpointImpl.NO_REENTRANT,
         )
-        check_fn = lambda submodule: isinstance(LlamaDecoderLayer)
+        check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer)
         apply_activation_checkpointing(
             model,
             checkpoint_wrapper_fn=non_reentrant_wrapper,
@@ -97,7 +97,9 @@ def fsdp(self, model):
                 convert_outputs_to_fp32(model.forward.__func__), model)
 
         # Use auto_wrap_poliy for nested wrapping instead of only a top-level FSDP.
-        auto_wrap_policy = ModuleWrapPolicy({LlamaDecoderLayer, })
+        auto_wrap_policy = ModuleWrapPolicy({
+            LlamaDecoderLayer,
+        })
 
         mixed_precision_policy = None
         if self.args.fp16 or self.args.bf16:
diff --git a/flashmodels/arguments.py.bak b/flashmodels/arguments.py.bak
deleted file mode 100644
index 8ad9eb7..0000000
--- a/flashmodels/arguments.py.bak
+++ /dev/null
@@ -1,273 +0,0 @@
-import argparse
-import os
-
-import torch
-
-from flashmodels.logger import logger
-from flashmodels.patch import patch_gemma, patch_llama, patch_peft
-
-
-def print_args(args):
-    logger.info("FlashModels Arguments: ")
-    logger.info(" \n".join(f" {k} = {v}" for k, v in vars(args).items()))
-
-
-def parse():
-    parser = argparse.ArgumentParser(description="Flash Models Arguments")
-
-    # model args
-    parser.add_argument("--model_name_or_path",
-                        type=str,
-                        default="decapoda-research/llama-7b-hf")
-    parser.add_argument("--cache_dir", type=str, default="./models/")
-    parser.add_argument("--max_seq_length", type=int, default=1024)
-    parser.add_argument(
-        "--model_type",
-        type=str,
-        default="",
-        choices=["gpt", "llama", "glm", "baichuan", "qwen", "olmo"])
-
-    # dataset args
-    parser.add_argument("--dataset_name_or_path",
-                        type=str,
-                        default="./data/wikitext-2-raw-v1.json")
-    parser.add_argument("--dataset_config", type=str, default="")
-    parser.add_argument("--micro_batch_size", type=int, default=8)
-    parser.add_argument("--padding_side", type=str, default="right")
-    parser.add_argument("--disable_train_sampler",
-                        action="store_true",
-                        help="Disable Train Sampler")
-
-    # accelerator args
-    parser.add_argument("--accelerator",
-                        type=str,
-                        default="acc",
-                        choices=["cuda", "acc", "megatron"],
-                        help="accelerator name")
-    parser.add_argument("--fsdp_num",
-                        type=int,
-                        default=1,
-                        help="Full sharded data parallel Number")
-    parser.add_argument("--gc",
-                        action="store_true",
-                        default=False,
-                        help="Use gradients checkpoint")
-    parser.add_argument(
-        "--gc_cnt",
-        type=int,
-        default=None,
-        help="Number of decoder layers for gradient checkpointing")
-    parser.add_argument("--tp_num",
-                        type=int,
-                        default=1,
-                        help="Tensor Parallel Number")
-    parser.add_argument("--sp",
-                        action="store_true",
-                        default=False,
-                        help="Use Sequence Parallelism.")
-    parser.add_argument(
-        "--sp_reshard_after_forward",
-        action="store_true",
-        default=False,
-        help="To reduce memory usage, reshard weight after forward in TP-SP, \
-        and perform an extra all-gather in the backward pass")
-    parser.add_argument("--sp_num",
-                        type=int,
-                        default=1,
-                        help="DeepSpeed Ulysses Sequence \
-                        Parallel Number. ")
-    parser.add_argument("--dp_num",
-                        type=int,
-                        default=1,
-                        help="Data Parallel Number")
-    parser.add_argument("--pp_num",
-                        type=int,
-                        default=1,
-                        help="Pipeline Parallel Number")
-    parser.add_argument("--fp16",
-                        action="store_true",
-                        help="Run model in fp16 mode.")
-    parser.add_argument("--bf16",
-                        action="store_true",
-                        help="Run model in bfloat16 mode.")
-    parser.add_argument("--force_use_syncfree_adam",
-                        action="store_true",
-                        help="Force to use \
-        syncfree.Adam/AdamW for better tracing peformance.")
-    parser.add_argument("--use_zero2",
-                        action="store_true",
-                        help="Use \
-        distributed optimizer(ZeRO2) for SPMD-DP.")
-    parser.add_argument("--use_zero3",
-                        action="store_true",
-                        help="Use \
-        ZeRO3 for SPMD-DP.")
-
-    # lora
-    parser.add_argument("--lora", action="store_true", help="Use lora")
-    parser.add_argument("--lora_r",
-                        type=int,
-                        default=8,
-                        help="lora attention dimension")
-    parser.add_argument("--lora_alpha",
-                        type=int,
-                        default=8,
-                        help="lora scaling alpha parameter")
-    parser.add_argument("--lora_dropout",
-                        type=float,
-                        default=0.0,
-                        help="The dropout probability \
-        for Lora layers")
-    parser.add_argument(
-        "--lora_target_modules",
-        type=str,
-        default="QKV",
-        choices=["QKV", "ALL"],
-        help="The modules to apply Lora to. ALL means all linear layers in \
-        decoder layer use lora, QKV means only qkv linears use lora")
-
-    # training args
-    parser.add_argument("--global_rank", type=int, default=0)
-    parser.add_argument("--resume_from_checkpoint",
-                        action="store_true",
-                        help="Resume from checkpoint, if true,"
-                        " load checkpoint from ckpt_dir")
-    parser.add_argument("--ckpt_dir", type=str, default="")
-    parser.add_argument("--ckpt_freq",
-                        type=int,
-                        default=100,
-                        help="The checkpoint frequency of local steps.")
-    parser.add_argument("--profile",
-                        action="store_true",
-                        help="Open pytorch profiler")
-    parser.add_argument("--profile_dir", type=str, default="./profile/")
-    parser.add_argument("--profile_stop_step",
-                        type=int,
-                        default=10,
-                        help="Maximum profiling steps")
-    parser.add_argument("--log_interval", type=int, default=1)
-    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
-    parser.add_argument("--max_step", type=int, default=-1)
-    parser.add_argument("--learning_rate",
-                        type=float,
-                        default=2e-5,
-                        help="The initial learning rate for AdamW.")
-    parser.add_argument("--weight_decay",
-                        type=float,
-                        default=0.03,
-                        help="Weight decay for AdamW if we apply some.")
-    parser.add_argument("--adam_beta1",
-                        type=float,
-                        default=0.9,
-                        help="Beta1 for AdamW optimizer")
-    parser.add_argument("--adam_beta2",
-                        type=float,
-                        default=0.999,
-                        help="Beta2 for AdamW optimizer")
-    parser.add_argument("--adam_epsilon",
-                        type=float,
-                        default=1e-8,
-                        help="Epsilon for AdamW optimizer.")
-    parser.add_argument("--max_grad_norm",
-                        type=float,
-                        default=1.0,
-                        help="Max gradient norm.")
-    parser.add_argument(
-        "--lr_scheduler_type",
-        type=str,
-        default="cosine",
-        help="The scheduler type to use.",
-        choices=[
-            "linear", "cosine", "cosine_with_restarts", "polynomial",
-            "constant", "constant_with_warmup"
-        ],
-    )
-    parser.add_argument(
-        "--warmup_ratio",
-        type=float,
-        default=0.0,
-        help="Linear warmup over warmup_ratio fraction of total steps.")
-    parser.add_argument("--warmup_steps",
-                        type=int,
-                        default=0,
-                        help="Linear warmup over warmup_steps.")
-    parser.add_argument("--num_train_epochs", type=int, default=1)
-    parser.add_argument("--padding_strategy",
-                        type=str,
-                        default="max_length",
-                        help="tokenizer padding strategy",
-                        choices=["max_length", "longest"])
-    parser.add_argument("--max_train_steps",
-                        type=int,
-                        default=-1,
-                        help="Maximum training steps")
-    parser.add_argument("--use_flash_attn",
-                        action="store_true",
-                        default=False,
-                        help="Use TriDao FlashAttention2")
-    parser.add_argument("--log_loss",
-                        action="store_true",
-                        help="Print loss when logging steps")
-
-    args = parser.parse_args()
-
-    if args.lora:
-        patch_peft()
-
-    if args.accelerator == "cuda":
-        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
-
-    args.global_rank = int(os.getenv("RANK", 0))
-    args.local_rank = int(os.getenv("LOCAL_RANK", 0))
-    args.world_size = int(os.getenv("WORLD_SIZE", 1))
-    if args.global_rank != 0:
-        args.profile = False
-
-    # mkdir for ckpt_dir
-    if len(args.ckpt_dir) > 0:
-        os.makedirs(args.ckpt_dir, exist_ok=True)
-
-    # amp checks.
-    args.dtype = torch.float
-    if args.fp16:
-        assert not args.bf16
-        args.dtype = torch.half
-
-    if args.bf16:
-        assert not args.fp16
-        args.dtype = torch.bfloat16
-
-    # DP/MP checks.
-    args.mp_num = args.pp_num * args.tp_num  # model parallel size.
-    args.dp_num = max(
-        1, args.world_size // (args.mp_num * args.fsdp_num * args.sp_num))
-
-    if not args.model_type:
-        if "llama" in args.model_name_or_path.lower():
-            args.model_type = "llama"
-        elif "gpt" in args.model_name_or_path.lower():
-            args.model_type = "gpt"
-        elif "glm" in args.model_name_or_path.lower():
-            args.model_type = "glm"
-        elif "baichuan" in args.model_name_or_path.lower():
-            args.model_type = "baichuan"
-        elif "qwen" in args.model_name_or_path.lower():
-            args.model_type = "qwen"
-        elif "olmo" in args.model_name_or_path.lower():
-            args.model_type = "olmo"
-        elif "gemma" in args.model_name_or_path.lower():
-            args.model_type = "gemma"
-        else:
-            raise NotImplementedError(
-                f"Unsupported model: {args.model_name_or_path}")
-
-    if args.model_type == "llama" and args.accelerator == 'acc' and (
-            args.fp16 or args.bf16):
-        patch_llama(args.use_flash_attn)
-    if args.model_type == "gemma" and args.accelerator == 'acc':
-        patch_gemma()
-
-    if args.local_rank == 0:
-        print_args(args)
-
-    return args
diff --git a/flashmodels/patch/__init__.py.bak b/flashmodels/patch/__init__.py.bak
deleted file mode 100644
index 452a833..0000000
--- a/flashmodels/patch/__init__.py.bak
+++ /dev/null
@@ -1,5 +0,0 @@
-from flashmodels.patch.patch import patch_gemma, patch_llama, patch_lora
-
-
-def patch_peft():
-    patch_lora()
diff --git a/flashmodels/patch/llama_model.py b/flashmodels/patch/llama_model.py
index 6b4a7ca..745735f 100644
--- a/flashmodels/patch/llama_model.py
+++ b/flashmodels/patch/llama_model.py
@@ -8,7 +8,6 @@
 import torch_xla.core.xla_model as xm
 from torch import nn
 from torchacc.dist.tp import Mesh, mark_sharding
-from transformer.cache_utils import Cache
 from transformers.models.llama.configuration_llama import LlamaConfig
 from transformers.models.llama.modeling_llama import (ACT2FN, LlamaRMSNorm,
                                                       LlamaRotaryEmbedding,
diff --git a/flashmodels/patch/patch.py b/flashmodels/patch/patch.py
index bc76295..4d6d36f 100644
--- a/flashmodels/patch/patch.py
+++ b/flashmodels/patch/patch.py
@@ -33,6 +33,8 @@ def rewrite_load():
 
 def patch_llama(use_flash_attn):
     patch.patch_llama(use_flash_attn)
+    from flashmodels.patch.llama_model import (LlamaAttention,
+                                               LlamaDecoderLayer, LlamaMLP)
     if os.environ.get("ACC_LLAMA_TP") == "1":
         transformers.models.llama.modeling_llama.LlamaMLP = LlamaMLP
         if os.getenv("XLA_USE_SPMD") == "1":
diff --git a/flashmodels/patch/patch.py.bak b/flashmodels/patch/patch.py.bak
deleted file mode 100644
index bc76295..0000000
--- a/flashmodels/patch/patch.py.bak
+++ /dev/null
@@ -1,111 +0,0 @@
-import difflib
-import inspect
-import os
-import re
-from typing import Any
-
-import torch
-import torchacc.utils.patch as patch
-import transformers
-
-from flashmodels.logger import logger
-
-
-def rewrite_load():
-    """Rewrite `torch.load` in `from_pretrain` in case to use mmap to reduce the CPU
-    memory pressure of loading multiple copies of data under multiple processes"""
-    source = inspect.getsource(transformers.modeling_utils)
-    modified = re.sub(r"torch\.load\((?![^)]*mmap[^)]*\))([^)]*)\)",
-                      r"torch.load(\g<1>, mmap=True)", source)
-    modified = re.sub(r"partial\(torch.load,(?![^)]*mmap[^)]*\))([^)]*)\)",
-                      r"partial(torch.load,\g<1>, mmap=True)", modified)
-    if (int(os.environ.get("LOCAL_RANK", 0)) == 0):
-        lines = difflib.ndiff(source.split("\n"), modified.split("\n"))
-        diff = "\n".join([
-            line for line in lines
-            if line.startswith("+") or line.startswith("-")
-        ])
-        logger.warning(
-            f"When set LOW_CPU_MEM_USAGE, all the `torch.load` in transfomers.modeling_utils "
-            f"are called with `mmap=True`, diff: \n{diff}")
-    exec(modified, transformers.modeling_utils.__dict__)
-
-
-def patch_llama(use_flash_attn):
-    patch.patch_llama(use_flash_attn)
-    if os.environ.get("ACC_LLAMA_TP") == "1":
-        transformers.models.llama.modeling_llama.LlamaMLP = LlamaMLP
-        if os.getenv("XLA_USE_SPMD") == "1":
-            # use einsum in linear for SPMD TP/Ulysses.
-            transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
-            transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
-
-    # (wenting.swt): Delete me when merged in transformers
-    if bool(int(os.environ.get("LOW_CPU_MEM_USAGE", "0"))):
-        rewrite_load()
-
-
-def patch_gemma():
-    # Set the attention_mask in GemmaAttention to None to match the pattern of FlashAttentionRewriter.
-    def wrap_for_flash_attention(func):
-        def wrapper(*args, **kwargs):
-            kwargs["attention_mask"] = None
-            return func(*args, **kwargs)
-
-        return wrapper
-
-    xla_flags = os.getenv('XLA_FLAGS', '').split(' ')
-    pattern = r'--xla_gpu_enable_flash_attention=(\w+)'
-    for flag in xla_flags:
-        match = re.search(pattern, flag)
-        if match:
-            value = match.group(1)
-            if str(value).lower() == "true":
-                transformers.models.gemma.modeling_gemma.GemmaAttention.forward = wrap_for_flash_attention(
-                    transformers.models.gemma.modeling_gemma.GemmaAttention.
-                    forward)
-
-
-def patch_lora():
-    try:
-        import peft
-        from peft.tuners import lora
-    except ImportError:
-        logger.errors("import lora fail, please install peft.")
-
-    def _forward_linear(self, x: torch.Tensor, *args: Any,
-                        **kwargs: Any) -> torch.Tensor:
-        if self.disable_adapters:
-            if self.merged:
-                self.unmerge()
-            if version.parse(peft.__version__) > version.parse("0.6.2"):
-                result = self.base_layer(x, *args, **kwargs)
-            else:
-                result = self._linear(x)
-        elif self.merged:
-            if version.parse(peft.__version__) > version.parse("0.6.2"):
-                result = self.base_layer(x, *args, **kwargs)
-            else:
-                result = self._linear(x)
-        else:
-            if version.parse(peft.__version__) > version.parse("0.6.2"):
-                result = self.base_layer(x, *args, **kwargs)
-            else:
-                result = self._linear(x)
-            torch_result_dtype = result.dtype
-            for active_adapter in self.active_adapters:
-                if active_adapter not in self.lora_A.keys():
-                    continue
-                lora_A = self.lora_A[active_adapter]
-                lora_B = self.lora_B[active_adapter]
-                dropout = self.lora_dropout[active_adapter]
-                scaling = self.scaling[active_adapter]
-                x = x.to(lora_A.weight.dtype)
-                result += lora_B(lora_A(dropout(x))) * scaling
-
-            result = result.to(torch_result_dtype)
-        return result
-
-    # TODO(baole): delete this patch after
-    # https://github.com/huggingface/peft/pull/1010 is merged.
-    lora.Linear.forward = _forward_linear