
Introduce Liger Fused Cross Entropy Kernel to FOAK Plugin #76

Open
achew010 opened this issue Aug 29, 2024 · 4 comments
Labels: triton (involves triton kernels tuning)

achew010 (Contributor) commented Aug 29, 2024

Description

Consider adding an additional FusedCrossEntropyLoss kernel to the FOAK set of kernels, given the additional improvement seen when using it in earlier tests (see Background below).

Considerations:

  • Enable usage of FusedCrossEntropyLoss in FOAK plugin
  • Enable FOAK kernels to be selectively activated for use in full finetuning and regular PEFT
  • (KIV) Consider enabling chunked CE loss for the current plugin

Background

We compared the current FOAK kernels against the kernels from Liger using Liger's full FT benchmark script with the following parameters:

  • model: "meta-llama/Meta-Llama-3-8B"
  • dataset: 10000 sample subset of "tatsu-lab/alpaca"
  • max_seq_length: 512
  • epoch: 1
  • num devices: 4 x A100-80GB GPUs

Four Triton kernels are activated in the comparison against their FOAK equivalents:

  • Fused MLP (SwiGLU) (MLP fused with Activation)
  • RoPE Embeddings
  • RMSNorm
  • CrossEntropyLoss / FusedCrossEntropyLoss (Last linear layer fused with loss)

The benchmarks report the following metrics:

  • avg_tokens_per_sec: Total input tokens seen by the model divided by the total runtime (secs) of each run.
  • total_peak_allocated_memory: Total peak allocated GPU memory in MB
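For clarity, the throughput metric reduces to a simple ratio. A minimal illustration with made-up numbers follows (the token count is an upper bound that ignores padding and truncation; the runtime is hypothetical):

    # Illustrative only: how avg_tokens_per_sec is derived; the runtime value is made up.
    total_input_tokens = 10_000 * 512        # samples x max_seq_length (upper bound)
    total_runtime_secs = 600.0               # hypothetical wall-clock time of the run
    avg_tokens_per_sec = total_input_tokens / total_runtime_secs
    print(f"{avg_tokens_per_sec:.0f} tokens/sec")  # ~8533 in this made-up example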

We observe that the FOAK kernels match Liger in both speed and memory consumption when all 4 kernels are used (with the unfused CrossEntropyLoss kernel), but Liger performs better with FusedCrossEntropyLoss for

  • speed (up to 20% improvement)
  • memory (up to 36% improvement)

[image: benchmark comparison of FOAK and Liger kernels]

Additional Notes

  • We also noticed that Liger's CrossEntropyLoss kernel doesn't support chunking over the LM vocab, unlike the current FOAK kernels from Unsloth. Chunking allows the loss computation to be performed more quickly in smaller chunks of the vocab before a final reduction over all the chunk results (see the sketch after this list). This could be a potential limitation/slowdown when the LM head of the model has a large vocab dimension (e.g. 256k).
  • Considering that the Liger kernels appear to be drop-in replacements for the FOAK kernels, we would expect a mix of FOAK and Liger kernels to be compatible in the current FOAK plugin for QPEFT.
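For intuition, here is a minimal PyTorch sketch of the vocab-chunking idea described above. This is not the Unsloth/FOAK Triton kernel; ignore_index handling and numerical-precision details are omitted, and the chunk size is an arbitrary illustrative default.

    import torch

    def chunked_cross_entropy(logits, labels, chunk_size=65536):
        # Sweep the vocab dimension in chunks, folding each chunk's partial
        # logsumexp into a running logsumexp, then reduce to the final loss.
        N, V = logits.shape
        running_lse = torch.full((N,), float("-inf"), dtype=logits.dtype, device=logits.device)
        for start in range(0, V, chunk_size):
            chunk = logits[:, start:start + chunk_size]
            running_lse = torch.logaddexp(running_lse, torch.logsumexp(chunk, dim=-1))
        # CE per row = logsumexp(logits_row) - logits_row[target]
        target_logit = logits.gather(1, labels.unsqueeze(1)).squeeze(1)
        return (running_lse - target_logit).mean()

With a 256k vocab, each pass only touches a slice of the logits at a time, which is why the lack of chunking could become a bottleneck for models with very large LM heads.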

Extracted from fms-acceleration FOAK slides

| model_name_or_path | framework_config | num_gpus | batch_size | tokens_per_second | % Increase in throughput | peak_mem_alloc_in_GiB |
|---|---|---|---|---|---|---|
| llama3/hf/70b_pre_trained | accelerated-peft-bnb | 2 | 2 | 398 | 0 | 49.0 |
| llama3/hf/70b_pre_trained | accelerated-peft-bnb-foak | 2 | 2 | 434 | 9 | 48.7 |
| llama3/hf/70b_pre_trained | accelerated-peft-bnb-liger | 2 | 2 | ? | ? | ? |
fabianlim (Contributor) commented:

@achew010 currently we only install the non-chunked CE loss from Unsloth. It seems to be OK for llama3, but scaling up, we should consider handling chunking as well.

achew010 (Contributor, Author) commented Sep 17, 2024

Considerations for Introducing FusedCrossEntropyLoss to FMS-Acceleration

Liger's FCELoss combines the LM head matmul with the CrossEntropyLoss kernel into a single operation to reduce intermediate overheads.
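As a functional reference (not the actual Liger kernel, and with an illustrative function name), the fused op computes the following; the real Triton kernel consumes the hidden states and lm_head weight directly and avoids materializing the full (N, vocab) logits tensor, which is the main source of the memory savings reported above.

    import torch
    import torch.nn.functional as F

    def fused_linear_cross_entropy_reference(hidden_states, lm_head_weight, labels):
        # Plain-PyTorch stand-in for what the fused linear + CE op computes.
        # Unlike this reference, the fused kernel never writes out the full
        # (N, vocab) logits (or their gradient) to GPU memory.
        logits = hidden_states @ lm_head_weight.T    # (N, vocab) intermediate
        return F.cross_entropy(logits.float(), labels)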

We can keep the additional FusedCrossEntropyLoss code inside plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger

There are 3 approaches to applying FusedCrossEntropyLoss in FMS-Acceleration. The first two options have tradeoffs in terms of maintainability and reliability; the third option requires additional documentation, as there are certain nuances that might not be easy to understand.

  1. Follow Liger's monkey patching
    1. Straightforward global patching that maintains the function signature
    2. Suffers maintainability issues every time the CausalLM forward is updated
    3. Less robust, as the patching is not properly tracked and handled, especially when done globally
  2. fms-acceleration custom forward patching
    1. Instance patching using model patcher rules
    2. Will also suffer maintainability issues every time the CausalLM forward is updated. See the minimal implementation of approach 2 here. In that implementation we internally maintain the modified LlamaForCausalLM forward function, which is hard to maintain for any other models we support later on.
      • If we try to avoid explicitly specifying the arguments to the forward patch, e.g. def forward(*args, **kwargs), this modifies the forward signature and causes problems with HFTrainer detecting what to discard; you might end up with missing dataset key errors. You can work around this by setting TrainingArguments.remove_unused_columns = False.
    3. Robust patching, as the fms-acceleration patcher state will track all patching done
  3. fms-acceleration internally patching lm_head and loss computation to support fused linear+loss
    1. Instance patching of self.lm_head and CrossEntropyLoss using model patcher rules; doesn't affect the forward signature of the CausalLM
    2. No maintainability issues since the forward signature isn't affected
    3. Robust patching, as the fms-acceleration patcher state will track all patching done
    4. Requires some documentation of the implementation to explain the intuition behind the patches:
      1. Patch the self.lm_head forward with a function that calculates the shifted hidden_states inside self.lm_head and holds them in scope
      2. Return placeholder output logits from self.lm_head with a very small compute footprint, since we won't use them; the subsequent computations on the logits (marked as dummy computations in the snippet below) are unused operations
      3. Patch the CrossEntropyLoss class to perform the fused matmul and loss computation with the shifted hidden_states held in scope (a rough sketch of these two patches follows the snippet below)

      Snippet of the LlamaForCausalLM forward:

      hidden_states = outputs[0]
      if self.config.pretraining_tp > 1:
          lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
          logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
          logits = torch.cat(logits, dim=-1)
      else:
          if labels is None and not is_torchdynamo_compiling():
              logger.warning_once(
                  "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
              )
          # at this point the logits produced are not used, so keep its footprint small.
          # this patched operation just computes the `shifted hidden_states` and keeps it for later when computing the loss
          logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
      
      loss = None
      if labels is not None:
          # dummy computations
          logits = logits.float()
          # dummy computation
          shift_logits = logits[..., :-1, :].contiguous()
          # labels will not be affected by patching
          shift_labels = labels[..., 1:].contiguous()
          # this loss function instantiated should be patched over as FusedCrossEntropyLoss
          loss_fct = CrossEntropyLoss()
          # dummy computation
          shift_logits = shift_logits.view(-1, self.config.vocab_size)
          # labels will not be affected by patching
          shift_labels = shift_labels.view(-1)
          shift_labels = shift_labels.to(shift_logits.device)
          # the patched loss function will access the pre-computed `shifted hidden_states` and perform the FusedCrossEntropyLoss computation here
          loss = loss_fct(shift_logits, shift_labels)
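A rough sketch of the two patches described in steps 1-3 above follows. The class names are hypothetical, a plain matmul + cross entropy stands in for the fused Triton kernel, and keeping the placeholder logits shape-compatible with the dummy computations in the snippet above is glossed over; that is part of the documentation nuance mentioned in point 4.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class PatchedLMHead(nn.Module):
        """Hypothetical wrapper patched over self.lm_head (steps 1 and 2): holds the
        shifted hidden_states in scope and returns cheap placeholder logits."""
        def __init__(self, lm_head: nn.Linear):
            super().__init__()
            self.lm_head = lm_head
            self.shifted_hidden_states = None

        def forward(self, hidden_states):
            # step 1: compute the shifted hidden_states and hold them in scope
            self.shifted_hidden_states = hidden_states[..., :-1, :]
            # step 2: return placeholder logits with a tiny footprint; the
            # subsequent dummy computations on them are never used
            return hidden_states.new_zeros(*hidden_states.shape[:-1], 1)

    class PatchedFusedCrossEntropyLoss(nn.Module):
        """Hypothetical replacement patched over CrossEntropyLoss() (step 3): ignores
        the placeholder logits and fuses the lm_head matmul with the loss."""
        def __init__(self, patched_lm_head: PatchedLMHead):
            super().__init__()
            self.head = patched_lm_head

        def forward(self, _placeholder_shift_logits, shift_labels):
            hidden = self.head.shifted_hidden_states          # (b, s-1, h)
            weight = self.head.lm_head.weight                 # (vocab, h)
            # the real implementation would call the fused Triton kernel here
            logits = hidden.reshape(-1, hidden.shape[-1]) @ weight.T
            return F.cross_entropy(logits.float(), shift_labels.view(-1))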
      

anhuong (Collaborator) commented Nov 10, 2024

Implemented (2) in PR #93, but as noted above it suffers from maintainability issues:

  1. llama and mistral have very similar forward functions, but mixtral has a different forward function that also needs to be copied in, and granite has a slight difference
  2. hard to maintain between different transformers versions
  3. was not able to show a train_runtime improvement from the other triton kernels, but was able to show a memory improvement (as noted in the PR)

Worked on implementing (3) via https://github.com/foundation-model-stack/fms-acceleration/compare/main...anhuong:fms-acceleration:fused-cross-entropyloss?expand=1; however, the model patching rules aren't quite right there.

Thus it could be worth implementing (2) via the PR for just transformers v4.44/v4.43, and then implementing the new solution with the model patching changes needed from #98.

anhuong (Collaborator) commented Dec 4, 2024

@fabianlim do you want to close this issue since we merged the Liger PR, or leave it open while we work towards the v2?
