Fix finite lorax generation in cb mode
Signed-off-by: Jou-An Chen <[email protected]>
quic-jouachen authored and quic-rishinr committed Jan 11, 2025
1 parent 1517d6a commit 8602096
Showing 3 changed files with 34 additions and 6 deletions.
4 changes: 3 additions & 1 deletion QEfficient/generation/text_generation_inference.py
@@ -341,7 +341,9 @@ def cloud_ai_100_exec_kv(
             perf_metrics=PerfMetrics(prefill_time, decode_perf, total_perf, total_time),
         )
     else:
-        exec_info = generate_text.generate(prompt=prompt, generation_len=generation_len)
+        exec_info = generate_text.generate(
+            prompt=prompt, generation_len=generation_len, prompt_to_lora_id_mapping=prompt_to_lora_id_mapping
+        )
 
     print_latency_stats_kv(prompt, exec_info=exec_info, automation=automation)
     return exec_info
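This hunk is the core of the fix: in the finite-generation branch (taken when a fixed generation_len is requested), cloud_ai_100_exec_kv previously dropped prompt_to_lora_id_mapping, so per-prompt LoRA adapter selection never reached the decode loop; it now forwards the mapping to generate_text.generate. The sketch below shows how the fixed path might be driven; it is illustrative only, and every argument other than prompt, generation_len, and prompt_to_lora_id_mapping (all visible in the hunk) is an assumption about the surrounding API.

# Illustrative sketch only; the tokenizer/qpc_path argument names and the
# mapping values are assumptions, not taken from this commit.
from transformers import AutoTokenizer

from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv

tokenizer = AutoTokenizer.from_pretrained("<base-model-id>")  # assumed tokenizer source
exec_info = cloud_ai_100_exec_kv(
    tokenizer=tokenizer,                 # assumed parameter name
    qpc_path="<path/to/compiled/qpc>",   # assumed parameter name
    prompt=["hello!", "hi"],
    generation_len=32,                   # finite generation: the branch this commit fixes
    prompt_to_lora_id_mapping=[1, 2],    # assumed per-prompt LoRA ids, one entry per prompt
)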
6 changes: 3 additions & 3 deletions QEfficient/peft/lora/auto.py
@@ -342,9 +342,9 @@ def generate(
         self,
         tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer],
         prompts: List[str],
-        device_id: List[int] = None,
         prompt_to_adapter_mapping: List[str] = None,
-        runtime: str = "AI_100",
+        device_id: Optional[List[int]] = None,
+        runtime: Optional[str] = "AI_100",
         **kwargs,
     ):
         """
@@ -355,9 +355,9 @@ def generate(
         ``Mandatory`` Args:
             :tokenizer (PreTrainedTokenizerFast or PreTrainedTokenizer): The tokenizer used in the inference
             :prompts (List[str]): List of prompts to run the execution.
-            :device_id (List[int]): Ids of devices for running the qpc pass as [0] in case of normal model / [0, 1, 2, 3] in case of tensor slicing model
             :prompt_to_adapter_mapping (List[str]): The sequence of the adapter names will be matched with sequence of prompts and corresponding adapters will be used for the prompts. "base" for base model (no adapter).
         ``optional`` Args:
+            :device_id (List[int]): Device IDs to be used for execution. If ``len(device_id) > 1``, it enables multiple card setup. If ``None``, auto-device-picker will be used. ``Defaults to None``.
             :runtime (str, optional): Only ``AI_100`` runtime is supported as of now; ``ONNXRT`` and ``PyTorch`` coming soon. Defaults to "AI_100".
         """
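Note that the reordered signature moves prompt_to_adapter_mapping ahead of device_id, so a caller that previously passed device_id positionally as the third argument would now be handing device ids to the adapter mapping. A minimal hedged sketch of keyword-only usage follows; qeff_model and tokenizer are assumed to come from earlier setup, and the adapter names are placeholders.

# Hedged sketch of the reordered generate() call; keyword arguments avoid any
# breakage from the new parameter order.
exec_info = qeff_model.generate(
    tokenizer=tokenizer,
    prompts=["hello!", "hi"],
    prompt_to_adapter_mapping=["adapter_0", "base"],  # one adapter name per prompt; "base" = no adapter
    device_id=None,       # optional: None lets the auto-device-picker choose
    runtime="AI_100",     # only AI_100 is supported as of now
)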
30 changes: 28 additions & 2 deletions tests/peft/lora/test_lora_model.py
@@ -195,10 +195,12 @@ def test_auto_lora_model_for_causal_lm_load_unload_adapter(base_model_name, adap
     assert qeff_model.unload_adapter("adapter_0")  # valid unload
 
 
-# test the export, export caching, compile, generate workflow
+# test the export, export caching, compile and generate workflow in noncb mode
 @pytest.mark.on_qaic
 @pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples[:1])
-def test_auto_lora_model_for_causal_lm_export_compile_generate(base_model_name, adapter_id_0, adapter_id_1, tmp_path):
+def test_auto_lora_model_for_causal_lm_noncb_export_compile_generate(
+    base_model_name, adapter_id_0, adapter_id_1, tmp_path
+):
     qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1)
 
     qeff_model.load_adapter(adapter_id_0, "adapter_0")
@@ -231,3 +233,27 @@ def test_auto_lora_model_for_causal_lm_export_compile_generate(base_model_name,
         prompts=prompts,
         prompt_to_adapter_mapping=["adapter_0", "adapter_1", "adapter_0", "base"],
     )
+
+
+# test the compile and generate workflow in cb mode
+@pytest.mark.on_qaic
+@pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples[:1])
+def test_auto_lora_model_for_causal_lm_cb_compile_generate(base_model_name, adapter_id_0, adapter_id_1, tmp_path):
+    qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(
+        base_model_name, continuous_batching=True, num_hidden_layers=1
+    )
+
+    qeff_model.load_adapter(adapter_id_0, "adapter_0")
+    qeff_model.load_adapter(adapter_id_1, "adapter_1")
+
+    # test compile
+    qeff_model.compile(prefill_seq_len=32, ctx_len=64, full_batch_size=2)
+    assert Path(qeff_model.qpc_path).is_dir()
+
+    # test generate
+    prompts = ["hello!", "hi", "hello, my name is", "hey"]
+    qeff_model.generate(
+        tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=base_model_name),
+        prompts=prompts,
+        prompt_to_adapter_mapping=["adapter_0", "adapter_1", "adapter_0", "base"],
+    )
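The new CB test above runs open-ended decode; the regression this commit targets is the finite case, where the caller pins generation_len. A hedged sketch of that scenario, assuming generate() forwards extra keyword arguments through **kwargs down to cloud_ai_100_exec_kv (as the first hunk suggests):

# Hedged sketch of finite-length generation in continuous-batching mode.
# Assumes generation_len travels through generate(**kwargs) to
# cloud_ai_100_exec_kv, where this commit now forwards the LoRA mapping.
qeff_model.generate(
    tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=base_model_name),
    prompts=["hello!", "hi", "hello, my name is", "hey"],
    prompt_to_adapter_mapping=["adapter_0", "adapter_1", "adapter_0", "base"],
    generation_len=20,  # fixed decode length: the case that previously lost the mapping
)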
