Add LoRA support in TE (#11)
hemildesai authored May 23, 2024
1 parent bc21964 commit 4c2f093
Showing 1 changed file with 75 additions and 0 deletions.
75 changes: 75 additions & 0 deletions praxis/contrib/gpu/scripts_gpu/te_helper.py
@@ -1,5 +1,7 @@
import os

from absl import logging

from praxis import base_layer
from praxis import pax_fiddle
from praxis import pytypes
@@ -9,6 +11,11 @@
from praxis.layers import attentions, grouped_query_attention, multi_query_attention
from praxis.layers import embedding_softmax
from praxis.layers import normalizations
from praxis.contrib.gpu.scripts_gpu.lora_layers import (
    LoraAttentionProjection,
    LoraCombinedQKVProjection,
    LoraLinear,
)

try:
    import transformer_engine.jax as te
@@ -233,6 +240,74 @@ def update_attn_te_tpl(te_tpl, attn_tpl):
assert len(stacked_transformer_obj.moe_layers) == 0
assert stacked_transformer_obj.ngrammer_tpls is None

def update_lora_te_tpl(te_tpl, transformer_layer_tpl):
    """Map Praxis LoRA layer templates onto TE's low-rank adaptation settings."""
    lora_enabled = False
    mlp_included_in_lora = False
    te_lora_scope = "none"
    lora_rank = None
    if (
        transformer_layer_tpl.tr_fflayer_tpl.fflayer_tpl.linear_tpl.__fn_or_cls__
        is LoraLinear
    ):
        lora_enabled = True
        mlp_included_in_lora = True
        current_rank = (
            transformer_layer_tpl.tr_fflayer_tpl.fflayer_tpl.linear_tpl.rank
        )
        lora_rank = (
            current_rank if lora_rank is None else lora_rank & current_rank
        )

    attention_included_in_lora = False
    if (
        hasattr(transformer_layer_tpl.tr_atten_tpl, "combined_qkv_proj_tpl")
        and transformer_layer_tpl.tr_atten_tpl.combined_qkv_proj_tpl.__fn_or_cls__
        is LoraCombinedQKVProjection
    ):
        lora_enabled = True
        attention_included_in_lora = True
        current_rank = (
            transformer_layer_tpl.tr_atten_tpl.combined_qkv_proj_tpl.rank
        )
        lora_rank = (
            current_rank if lora_rank is None else lora_rank & current_rank
        )

    if (
        hasattr(transformer_layer_tpl.tr_atten_tpl, "proj_tpl")
        and transformer_layer_tpl.tr_atten_tpl.proj_tpl.__fn_or_cls__
        is LoraAttentionProjection
    ):
        lora_enabled = True
        attention_included_in_lora = True
        current_rank = transformer_layer_tpl.tr_atten_tpl.proj_tpl.rank
        lora_rank = (
            current_rank if lora_rank is None else lora_rank & current_rank
        )

    if lora_enabled:
        # Matching ranks AND to themselves, so a zero result (rank 0, or
        # disjoint rank bits across layers) fails the assert below.
        assert (
            lora_rank > 0
        ), "LoRA rank should be the same for all layers and greater than 0."
        if attention_included_in_lora and mlp_included_in_lora:
            te_lora_scope = "all"
        elif attention_included_in_lora and not mlp_included_in_lora:
            te_lora_scope = "exclude_mlp"
        elif mlp_included_in_lora and not attention_included_in_lora:
            te_lora_scope = "mlp"

        te_tpl.low_rank_adaptation_scope = te_lora_scope
        te_tpl.low_rank_adaptation_dim = lora_rank

    return te_tpl

try:
    te_transformer_tpl = update_lora_te_tpl(
        te_transformer_tpl, transformer_layer_tpl
    )
except Exception as e:
    logging.warning(f"Unable to use LoRA with TE: {e}")


return te_transformer_tpl

@staticmethod
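For context, a minimal sketch (not part of this commit) of how the LoRA templates checked by update_lora_te_tpl might be set on a Praxis transformer layer template. The enable_lora helper and the rank value are hypothetical; the fields it touches (linear_tpl, combined_qkv_proj_tpl, proj_tpl, rank) are the ones the diff inspects, and pax_fiddle.Config is the standard Praxis template constructor.

from praxis import pax_fiddle
from praxis.contrib.gpu.scripts_gpu.lora_layers import (
    LoraAttentionProjection,
    LoraCombinedQKVProjection,
    LoraLinear,
)

def enable_lora(transformer_layer_tpl, rank=8):
    """Hypothetical helper: swap dense projections for LoRA templates with one shared rank."""
    # Feed-forward projection; detected via linear_tpl.__fn_or_cls__ is LoraLinear.
    transformer_layer_tpl.tr_fflayer_tpl.fflayer_tpl.linear_tpl = pax_fiddle.Config(
        LoraLinear, rank=rank
    )
    atten_tpl = transformer_layer_tpl.tr_atten_tpl
    # Combined QKV projection, if the attention template exposes one.
    if hasattr(atten_tpl, "combined_qkv_proj_tpl"):
        atten_tpl.combined_qkv_proj_tpl = pax_fiddle.Config(
            LoraCombinedQKVProjection, rank=rank
        )
    # Attention output projection.
    if hasattr(atten_tpl, "proj_tpl"):
        atten_tpl.proj_tpl = pax_fiddle.Config(LoraAttentionProjection, rank=rank)
    return transformer_layer_tpl

With both the feed-forward and attention projections swapped, update_lora_te_tpl sets low_rank_adaptation_scope="all" and low_rank_adaptation_dim=rank on the TE template; attention-only yields "exclude_mlp" and feed-forward-only yields "mlp". Note that replacing a sub-template wholesale drops any fields previously configured on it, so in practice one would carry those fields over.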
