Model: Split cache creation into a common function
Unifies the switch statement across both draft and model caches.

Signed-off-by: kingbri <[email protected]>
bdashore3 committed Aug 22, 2024
1 parent ecaddec commit 5002617
Showing 1 changed file with 40 additions and 52 deletions.
92 changes: 40 additions & 52 deletions backends/exllamav2/model.py
@@ -548,30 +548,11 @@ def progress(loaded_modules: int, total_modules: int)
         if not self.quiet:
             logger.info("Loading draft model: " + self.draft_config.model_dir)
 
-        if self.draft_cache_mode == "Q4":
-            self.draft_cache = ExLlamaV2Cache_Q4(
-                self.draft_model,
-                max_seq_len=self.cache_size,
-                lazy=True,
-            )
-        elif self.draft_cache_mode == "Q6":
-            self.draft_cache = ExLlamaV2Cache_Q6(
-                self.draft_model,
-                max_seq_len=self.cache_size,
-                lazy=True,
-            )
-        elif self.draft_cache_mode == "Q8":
-            self.draft_cache = ExLlamaV2Cache_Q8(
-                self.draft_model,
-                max_seq_len=self.cache_size,
-                lazy=True,
-            )
-        else:
-            self.draft_cache = ExLlamaV2Cache(
-                self.draft_model,
-                max_seq_len=self.cache_size,
-                lazy=True,
-            )
+        self.draft_cache = self.create_cache(
+            cache_mode=self.draft_cache_mode,
+            autosplit=True,
+        )
 
         for value in self.draft_model.load_autosplit_gen(
             self.draft_cache,
             reserve_vram=autosplit_reserve,
@@ -601,34 +582,10 @@ def progress(loaded_modules: int, total_modules: int)
             if value:
                 yield value
 
-        if self.cache_mode == "Q4":
-            self.cache = ExLlamaV2Cache_Q4(
-                self.model,
-                max_seq_len=self.cache_size,
-                lazy=self.gpu_split_auto,
-                batch_size=1,
-            )
-        elif self.cache_mode == "Q6":
-            self.cache = ExLlamaV2Cache_Q6(
-                self.model,
-                max_seq_len=self.cache_size,
-                lazy=self.gpu_split_auto,
-                batch_size=1,
-            )
-        elif self.cache_mode == "Q8":
-            self.cache = ExLlamaV2Cache_Q8(
-                self.model,
-                max_seq_len=self.cache_size,
-                lazy=self.gpu_split_auto,
-                batch_size=1,
-            )
-        else:
-            self.cache = ExLlamaV2Cache(
-                self.model,
-                max_seq_len=self.cache_size,
-                lazy=self.gpu_split_auto,
-                batch_size=1,
-            )
+        self.cache = self.create_cache(
+            cache_mode=self.cache_mode,
+            autosplit=self.gpu_split_auto,
+        )
 
         # Load model with autosplit
         if self.gpu_split_auto:
@@ -647,6 +604,37 @@ def progress(loaded_modules: int, total_modules: int)
         input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
         self.model.forward(input_ids, cache=self.cache, preprocess_only=True)
 
+    def create_cache(self, cache_mode: str, autosplit: bool):
+        match cache_mode:
+            case "Q4":
+                return ExLlamaV2Cache_Q4(
+                    self.model,
+                    max_seq_len=self.cache_size,
+                    lazy=autosplit,
+                    batch_size=1,
+                )
+            case "Q6":
+                return ExLlamaV2Cache_Q6(
+                    self.model,
+                    max_seq_len=self.cache_size,
+                    lazy=self.gpu_split_auto,
+                    batch_size=1,
+                )
+            case "Q8":
+                return ExLlamaV2Cache_Q8(
+                    self.model,
+                    max_seq_len=self.cache_size,
+                    lazy=autosplit,
+                    batch_size=1,
+                )
+            case _:
+                return ExLlamaV2Cache(
+                    self.model,
+                    max_seq_len=self.cache_size,
+                    lazy=self.gpu_split_auto,
+                    batch_size=1,
+                )
+
     async def create_generator(self):
         try:
             # Don't acquire locks unless a model is loaded
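
Note on the new helper: as shown in the diff, create_cache always builds the cache on self.model, and the Q6 and fallback branches still pass lazy=self.gpu_split_auto while the Q4 and Q8 branches use the new autosplit parameter. For illustration only, the same consolidation can be written as a free function that takes the target model and settings explicitly, so one code path could cover both the main and draft caches. Everything below (the free function, its parameter names, and the dict-based dispatch) is an assumption for this sketch, not code from the commit.

# Illustrative sketch only: a parameterized cache factory.
# The function name, parameters, and dict dispatch are assumptions,
# not this commit's code.
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Cache,
    ExLlamaV2Cache_Q4,
    ExLlamaV2Cache_Q6,
    ExLlamaV2Cache_Q8,
)

# Map cache mode strings to cache classes; unknown modes fall back to the FP16 cache.
CACHE_CLASSES = {
    "Q4": ExLlamaV2Cache_Q4,
    "Q6": ExLlamaV2Cache_Q6,
    "Q8": ExLlamaV2Cache_Q8,
}


def create_cache(model: ExLlamaV2, cache_mode: str, cache_size: int, autosplit: bool):
    """Build a KV cache of the requested type for the given model."""
    cache_class = CACHE_CLASSES.get(cache_mode, ExLlamaV2Cache)
    return cache_class(
        model,
        max_seq_len=cache_size,
        lazy=autosplit,
        batch_size=1,
    )

Under these assumptions, the main cache would be created with create_cache(self.model, self.cache_mode, self.cache_size, self.gpu_split_auto) and the draft cache with create_cache(self.draft_model, self.draft_cache_mode, self.cache_size, True), mirroring the two call sites in the diff.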