Commit

formatted
Signed-off-by: Onkar Chougule <[email protected]>
ochougul committed Jan 9, 2025
1 parent 1518387 commit 65ee919
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions QEfficient/transformers/models/modeling_auto.py
@@ -292,9 +292,11 @@ def compile(
 
         if self.continuous_batching and full_batch_size is None:
             raise TypeError("missing required argument: 'full_batch_size'")
 
         if cache_size_multiplier and not full_batch_size:
-            raise ValueError("Prefix caching is enabled only for continuous batching as of now. Please pass `full_batch_size` argument and make sure you pass `continuous_batching=True` in the `from_pretrained` call")
+            raise ValueError(
+                "Prefix caching is enabled only for continuous batching as of now. Please pass `full_batch_size` argument and make sure you pass `continuous_batching=True` in the `from_pretrained` call"
+            )
 
         # Define prefill specialization
         prefill_specialization = {
@@ -303,8 +305,10 @@ def compile(
             "seq_len": prefill_seq_len,
             "ctx_len": ctx_len,
         }
 
-        prefill_specialization.update({"full_batch_size": full_batch_size*cache_size_multiplier if cache_size_multiplier else full_batch_size}) if self.continuous_batching else None
+        prefill_specialization.update(
+            {"full_batch_size": full_batch_size * cache_size_multiplier if cache_size_multiplier else full_batch_size}
+        ) if self.continuous_batching else None
         prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else None
         specializations = [
             prefill_specialization,
@@ -317,7 +321,13 @@
             "seq_len": num_speculative_tokens + 1 if self.is_tlm else 1,
             "ctx_len": ctx_len,
         }
-        decode_specialization.update({"full_batch_size": full_batch_size*cache_size_multiplier if cache_size_multiplier else full_batch_size}) if self.continuous_batching else ...
+        decode_specialization.update(
+            {
+                "full_batch_size": full_batch_size * cache_size_multiplier
+                if cache_size_multiplier
+                else full_batch_size
+            }
+        ) if self.continuous_batching else ...
         decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else ...
         specializations.append(decode_specialization)

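For context, a minimal usage sketch of the behavior this diff encodes. The class name QEFFAutoModelForCausalLM, the model card, and the argument values are assumptions for illustration; continuous_batching, full_batch_size, and cache_size_multiplier come from the diff above.

    from QEfficient import QEFFAutoModelForCausalLM

    # Prefix caching currently requires continuous batching, per the
    # ValueError raised above.
    model = QEFFAutoModelForCausalLM.from_pretrained(
        "gpt2",  # assumed model card; any supported causal LM
        continuous_batching=True,
    )

    # With cache_size_multiplier set, the prefill and decode specializations
    # are compiled with full_batch_size * cache_size_multiplier KV-cache
    # slots (here 4 * 2 = 8) instead of full_batch_size.
    model.compile(
        full_batch_size=4,
        cache_size_multiplier=2,  # illustrative value
        prefill_seq_len=32,
        ctx_len=128,
    )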
