Commit 10e16a6

Minor fixes-1
Signed-off-by: Amit Raj <[email protected]>
1 parent: d547d6f

5 files changed: +18 -15 lines

QEfficient/base/modeling_qeff.py

Lines changed: 2 additions & 1 deletion
@@ -20,7 +20,7 @@
 import torch

 from QEfficient.base.onnx_transforms import OnnxTransform
-from QEfficient.base.pytorch_transforms import PytorchTransform
+from QEfficient.base.pytorch_transforms import PytorchTransform, append_tranform
 from QEfficient.compile.qnn_compiler import compile as qnn_compile
 from QEfficient.generation.cloud_infer import QAICInferenceSession
 from QEfficient.utils import constants, dump_qconfig

@@ -46,6 +46,7 @@ class QEFFBaseModel(ABC):
     def _transform_names(cls) -> List[str]:
         return [x.__name__ for x in cls._pytorch_transforms + cls._onnx_transforms]

+    @append_tranform
     def __init__(self, model: torch.nn.Module) -> None:
         super().__init__()
         self.model = model
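The decorator applied to __init__ here wires model-dependent setup into construction: before the body of __init__ runs, it can inspect the incoming model and extend the class-level _pytorch_transforms list. A minimal, self-contained sketch of the pattern; all names below (SPECIAL_ARCHS, append_transform, Base) are illustrative stand-ins, not the project's code:

SPECIAL_ARCHS = ["Llama4ForConditionalGeneration"]  # stand-in for VLM_SPLIT_GATE_UP_WEIGHTS

def append_transform(init):
    # Wrap __init__ so a transform is appended first when the model qualifies.
    def wrapper(self, model, *args, **kwargs):
        if model.__class__.__name__ in SPECIAL_ARCHS:
            self._pytorch_transforms.append("SplitGateUpWeightsTransform")
        return init(self, model, *args, **kwargs)
    return wrapper

class Base:
    _pytorch_transforms = []  # class-level list, shared across instances

    @append_transform
    def __init__(self, model):
        self.model = model

class Llama4ForConditionalGeneration:  # dummy model class for the demo
    pass

Base(Llama4ForConditionalGeneration())
print(Base._pytorch_transforms)  # ['SplitGateUpWeightsTransform']

Because _pytorch_transforms is a class attribute, the append mutates the shared list; as in the diff's wrapper, every qualifying construction appends again, so repeated instantiation accumulates duplicate entries unless the transform tolerates that.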

QEfficient/base/pytorch_transforms.py

Lines changed: 16 additions & 1 deletion
@@ -126,7 +126,9 @@ class SplitGateUpWeightsTransform(PytorchTransform):
     @classmethod
     def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
         transformed = False
-        model = model.language_model
+
+        model = model.language_model if hasattr(model, "language_model") else model
+
         num_layers = len(model.model.layers)
         delete_fused_key = True
         sd = model.state_dict()

@@ -158,3 +160,16 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
             print(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})")
             transformed = True
         return model, transformed
+
+
+VLM_SPLIT_GATE_UP_WEIGHTS = ["Llama4ForConditionalGeneration", "Llama4TextModel"]
+
+
+def append_tranform(func):
+    def wrapper(*args, **kwargs):
+        model_class = args[1].model.__class__.__name__ if hasattr(args[1], "model") else args[1].__class__.__name__
+        if model_class in VLM_SPLIT_GATE_UP_WEIGHTS:
+            args[0]._pytorch_transforms.append(SplitGateUpWeightsTransform)
+        return func(*args, **kwargs)
+
+    return wrapper
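The hasattr guard added above makes the transform accept either a multimodal wrapper that owns a language_model submodule or the bare text model itself. A tiny illustrative sketch, with SimpleNamespace objects standing in for real nn.Modules:

from types import SimpleNamespace

wrapper_model = SimpleNamespace(language_model=SimpleNamespace(kind="text decoder"))
bare_model = SimpleNamespace(kind="text decoder")

for m in (wrapper_model, bare_model):
    # Same guard as in the diff: unwrap only when a language_model exists.
    target = m.language_model if hasattr(m, "language_model") else m
    print(target.kind)  # both print: text decoder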

QEfficient/transformers/modeling_utils.py

Lines changed: 0 additions & 3 deletions
@@ -384,6 +384,3 @@ def _create_causal_mask(
         attention_mask = attention_mask.unsqueeze(1)

     return attention_mask
-
-
-VLM_SPLIT_GATE_UP_WEIGHTS = ["Llama4ForConditionalGeneration"]

QEfficient/transformers/models/modeling_auto.py

Lines changed: 0 additions & 5 deletions
@@ -27,15 +27,13 @@
 import QEfficient
 from QEfficient.base.modeling_qeff import QEFFBaseModel
 from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
-from QEfficient.base.pytorch_transforms import SplitGateUpWeightsTransform
 from QEfficient.generation.cloud_infer import QAICInferenceSession
 from QEfficient.generation.text_generation_inference import (
     CloudAI100ExecInfoNew,
     PerfMetrics,
     calculate_latency,
     get_compilation_dims,
 )
-from QEfficient.transformers.modeling_utils import VLM_SPLIT_GATE_UP_WEIGHTS
 from QEfficient.transformers.models.pytorch_transforms import (
     CustomOpsTransform,
     KVCacheModuleMethodMapperTransform,

@@ -469,9 +467,6 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel):
     _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

     def __init__(self, model):
-        if model.config.architectures[0] in VLM_SPLIT_GATE_UP_WEIGHTS:
-            self._pytorch_transforms.append(SplitGateUpWeightsTransform)
-
         super().__init__(model)
         self.model = model.get_qeff_language_decoder()
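The conditional removed here now lives behind the append_tranform decorator on QEFFBaseModel.__init__. Note that the dispatch also changed: the old code checked model.config.architectures[0], while the decorator checks class names (preferring the wrapped .model attribute when present), which is presumably why Llama4TextModel was added to VLM_SPLIT_GATE_UP_WEIGHTS. A sketch of the new lookup, with Wrapper as a hypothetical stand-in:

VLM_SPLIT_GATE_UP_WEIGHTS = ["Llama4ForConditionalGeneration", "Llama4TextModel"]

class Llama4TextModel:  # stand-in for the HF text model
    pass

class Wrapper:  # stand-in for an object exposing a .model attribute
    def __init__(self):
        self.model = Llama4TextModel()

obj = Wrapper()
# Same expression the decorator uses: prefer the wrapped model's class name.
name = obj.model.__class__.__name__ if hasattr(obj, "model") else obj.__class__.__name__
print(name in VLM_SPLIT_GATE_UP_WEIGHTS)  # True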

examples/llama4_lm_example.py

Lines changed: 0 additions & 5 deletions
@@ -13,16 +13,12 @@
 from QEfficient.utils.constants import Constants
 from QEfficient.utils.run_utils import ApiRunner

-torch.manual_seed(42)
-
 model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
 model = Llama4ForCausalLM.from_pretrained(
     model_id, torch_dtype=torch.float32, use_cache=True, attn_implementation="eager"
 )
 model.eval()

-original_sd = model.state_dict()
-
 tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_id)
 config = model.config
 batch_size = len(Constants.INPUT_STR)

@@ -37,7 +33,6 @@

 qeff_model = QEFFAutoModelForCausalLM(model)

-onnx_model_path = qeff_model.export()
 qpc_path = qeff_model.compile(
     prefill_seq_len=128,
     ctx_len=2048,
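With the explicit export() call dropped, the example leaves ONNX export to compile(); the assumption, suggested by this diff rather than stated in it, is that compile() exports on demand when no ONNX model has been produced yet. The minimal flow then reduces to:

qeff_model = QEFFAutoModelForCausalLM(model)
qpc_path = qeff_model.compile(prefill_seq_len=128, ctx_len=2048)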
