diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index ff11531..d692233 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -90,6 +90,7 @@ class ExllamaV2Container:
 
     # GPU split vars
     gpu_split: Optional[list] = None
+    draft_gpu_split: Optional[list] = None
    gpu_split_auto: bool = True
    autosplit_reserve: List[float] = [96 * 1024**2]
    use_tp: bool = False
@@ -180,6 +181,7 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
                 )
                 draft_model_path = draft_model_path / draft_model_name
 
+            self.draft_gpu_split = draft_args.get("draft_gpu_split")
             self.draft_model_dir = draft_model_path
             self.draft_config.model_dir = str(draft_model_path.resolve())
             self.draft_config.prepare()
@@ -232,6 +234,16 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
                 for value in autosplit_reserve_megabytes
             ]
 
+        if self.draft_gpu_split:
+            self.gpu_split_auto = False
+            self.gpu_split = gpu_split
+
+            gpu_device_list = [
+                device_idx
+                for device_idx, memory in enumerate(self.draft_gpu_split)
+                if memory > 0
+            ]
+
         # Hardcode max output length to 16
         self.config.max_output_len = 16
 
@@ -617,21 +629,37 @@ def progress(loaded_modules: int, total_modules: int)
 
             # Draft uses the autosplit loader, so create a cache that reflects this
             draft_cache_class = self.get_cache_class(self.draft_cache_mode)
 
-            self.draft_cache = self.create_cache(
-                cache_class=draft_cache_class,
-                autosplit=True,
-                use_tp=False,
-                model=self.draft_model,
-            )
-            for value in self.draft_model.load_autosplit_gen(
-                self.draft_cache,
-                reserve_vram=autosplit_reserve,
-                last_id_only=True,
-                callback_gen=progress_callback,
-            ):
-                if value:
-                    yield value
+            if self.draft_gpu_split:
+                for value in self.draft_model.load_gen(
+                    self.draft_gpu_split,
+                    callback_gen=progress_callback,
+                ):
+                    if value:
+                        yield value
+
+                self.draft_cache = self.create_cache(
+                    cache_class=draft_cache_class,
+                    autosplit=False,
+                    use_tp=False,
+                    model=self.draft_model,
+                )
+            else:
+                self.draft_cache = self.create_cache(
+                    cache_class=draft_cache_class,
+                    autosplit=True,
+                    use_tp=False,
+                    model=self.draft_model,
+                )
+
+                for value in self.draft_model.load_autosplit_gen(
+                    self.draft_cache,
+                    reserve_vram=autosplit_reserve,
+                    last_id_only=True,
+                    callback_gen=progress_callback,
+                ):
+                    if value:
+                        yield value
 
         # Test VRAM allocation with a full-length forward pass
         input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
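
Note: below is a minimal sketch (not part of the patch) of how the new `draft_gpu_split` path behaves. When a manual split is supplied, the draft model is loaded with `load_gen` and the listed per-device allocations instead of the autosplit loader, and the comprehension added in `create()` keeps only the devices given a non-zero allocation. The split values shown here are hypothetical.

```python
# Illustration only: derive the device list from a hypothetical manual split.
# Values are per-GPU VRAM allocations (in GB), mirroring gpu_split's format.
draft_gpu_split = [0, 20, 24]  # skip GPU 0, place the draft model on GPUs 1 and 2

# Same comprehension as the code added to create(): keep devices with memory > 0
gpu_device_list = [
    device_idx
    for device_idx, memory in enumerate(draft_gpu_split)
    if memory > 0
]

print(gpu_device_list)  # -> [1, 2]
```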