From 24c686b03c427e00f0905cb294339fd4c2968dbd Mon Sep 17 00:00:00 2001
From: Dvid Noel Ng
Date: Fri, 12 Jan 2024 15:37:38 +0100
Subject: [PATCH 1/6] Ignore the .so file

---
 .gitignore | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/.gitignore b/.gitignore
index 8b56a79c..9b14f350 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,4 +5,5 @@ build/
 __pycache__/
 .idea
 venv
-dist
\ No newline at end of file
+dist
+*.so
\ No newline at end of file

From baa1efec3f6e55f9a27c07e6503855d48fe8b5f6 Mon Sep 17 00:00:00 2001
From: Dvid Noel Ng
Date: Fri, 12 Jan 2024 16:52:30 +0100
Subject: [PATCH 2/6] Allow layer repeats

---
 exllamav2/model.py      | 92 ++++++++++++++++++++++++++++-------------
 exllamav2/model_init.py | 16 ++++++-
 2 files changed, 77 insertions(+), 31 deletions(-)

diff --git a/exllamav2/model.py b/exllamav2/model.py
index 8154b0ef..03575de7 100644
--- a/exllamav2/model.py
+++ b/exllamav2/model.py
@@ -130,14 +130,14 @@ def __init__(self, config: ExLlamaV2Config, lazy_load = False):
         self.modules.append(ExLlamaV2Embedding(self, "model.embed_tokens"))
         self.modules_dict[self.modules[-1].key] = self.modules[-1]

-        for layer_idx in range(self.config.num_hidden_layers):
+        for layer_list in range(self.config.num_hidden_layers):

-            self.modules.append(ExLlamaV2Attention(self, f"model.layers.{layer_idx}", layer_idx))
+            self.modules.append(ExLlamaV2Attention(self, f"model.layers.{layer_list}", layer_list))
             for m in self.modules[-1].submodules: self.modules_dict[m.key] = m
             if self.config.architecture == "Mixtral":
-                self.modules.append(ExLlamaV2MoEMLP(self, f"model.layers.{layer_idx}", layer_idx))
+                self.modules.append(ExLlamaV2MoEMLP(self, f"model.layers.{layer_list}", layer_list))
             else:
-                self.modules.append(ExLlamaV2MLP(self, f"model.layers.{layer_idx}", layer_idx))
+                self.modules.append(ExLlamaV2MLP(self, f"model.layers.{layer_list}", layer_list))
             for m in self.modules[-1].submodules: self.modules_dict[m.key] = m

@@ -150,15 +150,40 @@ def __init__(self, config: ExLlamaV2Config, lazy_load = False):

         # Find last layer that affects k/v cache

-        layer_idx = len(self.modules)
+        layer_list = len(self.modules)
         while True:
-            layer_idx -= 1
-            if isinstance(self.modules[layer_idx], ExLlamaV2Attention):
+            layer_list -= 1
+            if isinstance(self.modules[layer_list], ExLlamaV2Attention):
                 break

-        self.last_kv_layer_idx = layer_idx
+        self.last_kv_layer_idx = layer_list
+
+        if hasattr(config, 'repeats'):
+            self.layers = []
+
+            def listLeftIndex(alist, value):
+                if value == 0:
+                    return 0
+                return alist.index(str(value))
+
+            def listRightIndex(alist, value):
+                if value > len(alist):
+                    return -1
+                return len(alist) - alist[-1::-1].index(str(value)) -1
+
+            layer_list = [layer.key.split(".")[-1] for layer in self.modules]
+
+            for interval in config.repeats:
+                start_idx = listLeftIndex(layer_list, interval[0])
+                end_idx = listRightIndex(layer_list, interval[1])
+                self.layers.extend(list(range(start_idx, end_idx + 1)))
+            self.layers.extend(list(range(listRightIndex(layer_list, config.repeats[-1][1]), len(layer_list))))
+
+            # If we have created a Frankenmerge, let's print it to verify!
+            for layer in self.layers:
+                print(layer, self.modules[layer].key)
+

     def set_device_map(self, allocation, embed_cpu = True):

         self.cache_map = {}
@@ -582,6 +607,23 @@ def _forward(self,
                  return_last_state = False,
                  position_offsets = None):

+        def process_module(module, x, last_state):
+            device = _torch_device(module.device_idx)
+
+            if idx == self.head_layer_idx:
+                if last_id_only and return_last_state:
+                    x = x.narrow(-2, -1, 1)
+                    last_state = x
+                elif last_id_only:
+                    x = x.narrow(-2, -1, 1)
+                elif return_last_state:
+                    last_state = x.narrow(-2, -1, 1)
+
+            x = safe_move_tensor(x, device)
+            x = module.forward(x, cache=cache, attn_params=attn_params, past_len=past_len, loras=loras)
+
+            return x, last_state
+
         batch_size, seq_len = input_ids.shape
         past_len = 0
         if cache is not None:
@@ -596,27 +638,19 @@ def _forward(self,
         attn_params = ExLlamaV2Attention.Params(batch_size, seq_len, past_len, input_mask, position_offsets)
         last_state = None

-        for idx, module in enumerate(self.modules):
-
-            device = _torch_device(module.device_idx)
-
-            # Onward
-
-            if idx == self.head_layer_idx:
-                if last_id_only and return_last_state:
-                    x = x.narrow(-2, -1, 1)
-                    last_state = x
-                elif last_id_only:
-                    x = x.narrow(-2, -1, 1)
-                elif return_last_state:
-                    last_state = x.narrow(-2, -1, 1)
-
-            x = safe_move_tensor(x, device)
-            x = module.forward(x, cache = cache, attn_params = attn_params, past_len = past_len, loras = loras)
-
-            if preprocess_only and idx == self.last_kv_layer_idx:
-                x = None
-                break
+        if hasattr(self, 'layers'):
+            for i, idx in enumerate(self.layers):
+                module = self.modules[idx]
+                x, last_state = process_module(module, x, last_state)
+                if preprocess_only and idx == self.last_kv_layer_idx:
+                    x = None
+                    break
+        else:
+            for idx, module in enumerate(self.modules):
+                x, last_state = process_module(module, x, last_state)
+                if preprocess_only and idx == self.last_kv_layer_idx:
+                    x = None
+                    break

         # Advance cache

diff --git a/exllamav2/model_init.py b/exllamav2/model_init.py
index 834aa465..d6c3d094 100644
--- a/exllamav2/model_init.py
+++ b/exllamav2/model_init.py
@@ -1,5 +1,5 @@

-import argparse, sys, os, glob
+import argparse, sys, os, glob, ast

 from exllamav2 import(
     ExLlamaV2,
@@ -17,6 +17,7 @@ def add_args(parser):
     parser.add_argument("-nfa", "--no_flash_attn", action = "store_true", help = "Disable Flash Attention")
     parser.add_argument("-lm", "--low_mem", action = "store_true", help = "Enable VRAM optimizations, potentially trading off speed")
     parser.add_argument("-ept", "--experts_per_token", type = int, help = "Override MoE model's default number of experts per token")
+    parser.add_argument("--repeats", type=parse_tuple_list, help="List of tuples of the layers to repeat")

 def print_options(args):

@@ -60,6 +61,16 @@ def check_args(args):
             print(f" ## Error: Cannot find {filename} in {args.model_dir}")
             sys.exit()

+def parse_tuple_list(string):
+    try:
+        # Safely evaluate the string as a Python literal (list of tuples)
+        tuple_list = ast.literal_eval(string)
+        if not all(isinstance(item, tuple) for item in tuple_list):
+            raise ValueError
+        return tuple_list
+    except:
+        raise argparse.ArgumentTypeError("Input must be a list of tuples")
+

 def init(args, quiet = False, allow_auto_split = False, skip_load = False):

@@ -76,7 +87,8 @@ def init(args, quiet = False, allow_auto_split = False, skip_load = False):
     if args.rope_alpha: config.scale_alpha_value = args.rope_alpha
     config.no_flash_attn = args.no_flash_attn
     if args.experts_per_token: config.num_experts_per_token = args.experts_per_token
-
+    if args.repeats: config.repeats = args.repeats
+
     # Set low-mem options

     if args.low_mem: config.set_low_mem()
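Aside: the `--repeats` string added in PATCH 2/6 is parsed by `ast.literal_eval` into a list of inclusive (start, end) layer intervals, and `listLeftIndex`/`listRightIndex` translate those into indices into `self.modules` (two modules per decoder layer, plus embedding, norm and head). A minimal sketch of that index arithmetic, using a toy 4-layer key list in place of real module keys (the keys and intervals are illustrative, not from the patch):

    # Toy stand-in for [layer.key.split(".")[-1] for layer in self.modules]:
    # embedding, then attention + MLP per layer, then norm and head.
    keys = ["embed_tokens", "0", "0", "1", "1", "2", "2", "3", "3", "norm", "lm_head"]

    def listLeftIndex(alist, value):
        if value == 0:
            return 0                       # an interval starting at 0 includes the embedding
        return alist.index(str(value))     # first module of layer `value`

    def listRightIndex(alist, value):
        if value > len(alist):
            return -1
        return len(alist) - alist[-1::-1].index(str(value)) - 1   # last module of layer `value`

    layers = []
    repeats = [(0, 2), (1, 3)]             # illustrative inclusive intervals
    for a, b in repeats:
        layers.extend(range(listLeftIndex(keys, a), listRightIndex(keys, b) + 1))
    layers.extend(range(listRightIndex(keys, repeats[-1][1]), len(keys)))

    print(layers)   # [0, 1, 2, 3, 4, 5, 6, 3, 4, 5, 6, 7, 8, 8, 9, 10]

Layers 1-2 are replayed; note that index 8 appears twice at the junction with the tail modules, a quirk of this first version that PATCH 5/6 later reworks.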
From 422064282312db47dc3c3a021a7be4e3a37d6f58 Mon Sep 17 00:00:00 2001
From: Dvid Noel Ng
Date: Fri, 12 Jan 2024 17:36:13 +0100
Subject: [PATCH 3/6] Rename and tidy layer printing

---
 exllamav2/model.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/exllamav2/model.py b/exllamav2/model.py
index 03575de7..f044549c 100644
--- a/exllamav2/model.py
+++ b/exllamav2/model.py
@@ -160,7 +160,7 @@ def __init__(self, config: ExLlamaV2Config, lazy_load = False):

         if hasattr(config, 'repeats'):
-            self.layers = []
+            self.layers_list = []

             def listLeftIndex(alist, value):
                 if value == 0:
@@ -177,12 +177,13 @@ def listRightIndex(alist, value):
             for interval in config.repeats:
                 start_idx = listLeftIndex(layer_list, interval[0])
                 end_idx = listRightIndex(layer_list, interval[1])
-                self.layers.extend(list(range(start_idx, end_idx + 1)))
-            self.layers.extend(list(range(listRightIndex(layer_list, config.repeats[-1][1]), len(layer_list))))
+                self.layers_list.extend(list(range(start_idx, end_idx + 1)))
+            self.layers_list.extend(list(range(listRightIndex(layer_list, config.repeats[-1][1]), len(layer_list))))

             # If we have created a Frankenmerge, let's print it to verify!
-            for layer in self.layers:
-                print(layer, self.modules[layer].key)
+            print("Frankenstein Layers list:")
+            for i, layer in enumerate(self.layers_list):
+                print(i, self.modules[layer].key)


     def set_device_map(self, allocation, embed_cpu = True):
@@ -639,7 +640,7 @@ def process_module(module, x, last_state):
         last_state = None

         if hasattr(self, 'layers'):
-            for i, idx in enumerate(self.layers):
+            for i, idx in enumerate(self.layers_list):
                 module = self.modules[idx]
                 x, last_state = process_module(module, x, last_state)
                 if preprocess_only and idx == self.last_kv_layer_idx:
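With the toy arrangement above, the tidied printout from PATCH 3/6 would read along these lines (illustrative, not captured output; the first column is the position in the repeated sequence, the second the key of the module replayed there, so repeated layers show up as repeated keys):

    Frankenstein Layers list:
    0 model.embed_tokens
    1 model.layers.0
    2 model.layers.0
    3 model.layers.1
    4 model.layers.1
    ...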
From 3afecf7361750ea38c1ae0296dce9d971f9604fc Mon Sep 17 00:00:00 2001
From: Dvid Noel Ng
Date: Fri, 12 Jan 2024 17:53:11 +0100
Subject: [PATCH 4/6] Fix dumb bug

---
 exllamav2/model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/exllamav2/model.py b/exllamav2/model.py
index f044549c..9875f782 100644
--- a/exllamav2/model.py
+++ b/exllamav2/model.py
@@ -639,7 +639,7 @@ def process_module(module, x, last_state):
         attn_params = ExLlamaV2Attention.Params(batch_size, seq_len, past_len, input_mask, position_offsets)
         last_state = None

-        if hasattr(self, 'layers'):
+        if hasattr(self, 'layers_list'):
             for i, idx in enumerate(self.layers_list):
                 module = self.modules[idx]
                 x, last_state = process_module(module, x, last_state)
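The "dumb bug": PATCH 3/6 renamed the attribute from `self.layers` to `self.layers_list` but left the `hasattr(self, 'layers')` guard untouched, so the repeat branch in `_forward` was silently never taken. A tiny illustration of the failure mode (toy class, not the real model object):

    class Model: pass

    m = Model()
    m.layers_list = [0, 1, 2]           # attribute as renamed in PATCH 3
    print(hasattr(m, 'layers'))         # False -> the frankenmerge path was skipped
    print(hasattr(m, 'layers_list'))    # True  -> the check PATCH 4 switches to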
From a4e75d457c9ebde5fd39d96032044c04eee09419 Mon Sep 17 00:00:00 2001
From: Dvid Noel Ng
Date: Sat, 13 Jan 2024 09:40:52 +0100
Subject: [PATCH 5/6] Updated to match the format of Mergekit

---
 exllamav2/model.py      | 42 +++++++++++++++++-------------------------
 exllamav2/model_init.py | 12 +++++++++---
 2 files changed, 26 insertions(+), 28 deletions(-)

diff --git a/exllamav2/model.py b/exllamav2/model.py
index 9875f782..4dc8a34a 100644
--- a/exllamav2/model.py
+++ b/exllamav2/model.py
@@ -158,32 +158,24 @@ def __init__(self, config: ExLlamaV2Config, lazy_load = False):

         self.last_kv_layer_idx = layer_list

-
         if hasattr(config, 'repeats'):
-            self.layers_list = []
-
-            def listLeftIndex(alist, value):
-                if value == 0:
-                    return 0
-                return alist.index(str(value))
-
-            def listRightIndex(alist, value):
-                if value > len(alist):
-                    return -1
-                return len(alist) - alist[-1::-1].index(str(value)) -1
-
-            layer_list = [layer.key.split(".")[-1] for layer in self.modules]
-
-            for interval in config.repeats:
-                start_idx = listLeftIndex(layer_list, interval[0])
-                end_idx = listRightIndex(layer_list, interval[1])
-                self.layers_list.extend(list(range(start_idx, end_idx + 1)))
-            self.layers_list.extend(list(range(listRightIndex(layer_list, config.repeats[-1][1]), len(layer_list))))
-
-            # If we have created a Frankenmerge, let's print it to verify!
-            print("Frankenstein Layers list:")
-            for i, layer in enumerate(self.layers_list):
-                print(i, self.modules[layer].key)
+            embedTokenLayers = 1
+            transformerSublayers = 2
+            layer_arrangement = [list(range(*interval)) for interval in config.repeats]
+            layer_arrangement = [item for sublist in layer_arrangement for item in sublist]
+            LayeredModules = self.modules
+
+
+            self.modules = LayeredModules[:embedTokenLayers]
+            for idx in layer_arrangement:
+                self.modules += LayeredModules[idx*transformerSublayers + embedTokenLayers : idx*transformerSublayers + transformerSublayers + embedTokenLayers]
+            self.modules += LayeredModules[-2:]
+            self.head_layer_idx = len(self.modules) -1
+            self.last_kv_layer_idx = len(self.modules) -4
+
+            for i, m in enumerate(self.modules):
+                print(i, m.key)
+

     def set_device_map(self, allocation, embed_cpu = True):

diff --git a/exllamav2/model_init.py b/exllamav2/model_init.py
index d6c3d094..89c9708b 100644
--- a/exllamav2/model_init.py
+++ b/exllamav2/model_init.py
@@ -65,11 +65,17 @@ def parse_tuple_list(string):
     try:
         # Safely evaluate the string as a Python literal (list of tuples)
         tuple_list = ast.literal_eval(string)
+
+        # Ensure all elements in the list are tuples
         if not all(isinstance(item, tuple) for item in tuple_list):
-            raise ValueError
-        return tuple_list
+            raise ValueError("All elements must be tuples")
+
+        # Convert tuple elements to integers
+        int_tuple_list = [tuple(int(x) for x in item) for item in tuple_list]
+
+        return int_tuple_list
     except:
-        raise argparse.ArgumentTypeError("Input must be a list of tuples")
+        raise argparse.ArgumentTypeError("Input must be a valid list of tuples with integer elements")


 def init(args, quiet = False, allow_auto_split = False, skip_load = False):
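PATCH 5/6 switches interval semantics from the inclusive module-key search above to mergekit's half-open [start, end) layer ranges, expanded with `range(*interval)`, and rebuilds `self.modules` directly by slicing: one embedding module up front, `transformerSublayers = 2` modules (attention + MLP) per decoder layer, and the final two modules (norm and head) at the end. A sketch of that arithmetic with strings standing in for module objects (the 8-layer model and the intervals are illustrative):

    embedTokenLayers = 1
    transformerSublayers = 2

    # Toy module list: embedding, (attn, mlp) per layer for 8 layers, then norm and head.
    modules = ["embed"] + [m for i in range(8) for m in (f"attn.{i}", f"mlp.{i}")] + ["norm", "head"]

    repeats = [(0, 4), (2, 8)]                 # mergekit-style, half-open intervals
    layer_arrangement = [i for interval in repeats for i in range(*interval)]
    print(layer_arrangement)                   # [0, 1, 2, 3, 2, 3, 4, 5, 6, 7]

    new_modules = modules[:embedTokenLayers]
    for idx in layer_arrangement:
        start = idx * transformerSublayers + embedTokenLayers
        new_modules += modules[start : start + transformerSublayers]
    new_modules += modules[-2:]

    print(new_modules[:5])                     # ['embed', 'attn.0', 'mlp.0', 'attn.1', 'mlp.1']
    print(new_modules[-4:])                    # ['attn.7', 'mlp.7', 'norm', 'head']

Since `self.modules` is rebuilt here, the recomputed indices in the patch line up: `head_layer_idx = len(self.modules) - 1` is the head, and `last_kv_layer_idx = len(self.modules) - 4` is the final attention module.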
From 63e5c3475f598689d6117812dd30ca175fefec95 Mon Sep 17 00:00:00 2001
From: Dvid Noel Ng
Date: Sat, 13 Jan 2024 11:06:31 +0100
Subject: [PATCH 6/6] fix

---
 exllamav2/model.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/exllamav2/model.py b/exllamav2/model.py
index 4dc8a34a..a4346efa 100644
--- a/exllamav2/model.py
+++ b/exllamav2/model.py
@@ -163,19 +163,20 @@ def __init__(self, config: ExLlamaV2Config, lazy_load = False):
             transformerSublayers = 2
             layer_arrangement = [list(range(*interval)) for interval in config.repeats]
             layer_arrangement = [item for sublist in layer_arrangement for item in sublist]
-            LayeredModules = self.modules


-            self.modules = LayeredModules[:embedTokenLayers]
+            LayeredModules = self.modules[:embedTokenLayers]
             for idx in layer_arrangement:
-                self.modules += LayeredModules[idx*transformerSublayers + embedTokenLayers : idx*transformerSublayers + transformerSublayers + embedTokenLayers]
-            self.modules += LayeredModules[-2:]
+                LayeredModules += self.modules[idx*transformerSublayers + embedTokenLayers : idx*transformerSublayers + transformerSublayers + embedTokenLayers]
+            LayeredModules += self.modules[-2:]
             self.head_layer_idx = len(self.modules) -1
             self.last_kv_layer_idx = len(self.modules) -4

-            for i, m in enumerate(self.modules):
+            for i, m in enumerate(LayeredModules):
                 print(i, m.key)

+            self.layeredModules = LayeredModules
+

     def set_device_map(self, allocation, embed_cpu = True):
@@ -631,9 +632,8 @@ def process_module(module, x, last_state):
         attn_params = ExLlamaV2Attention.Params(batch_size, seq_len, past_len, input_mask, position_offsets)
         last_state = None

-        if hasattr(self, 'layers_list'):
-            for i, idx in enumerate(self.layers_list):
-                module = self.modules[idx]
+        if hasattr(self, 'layeredModules'):
+            for idx, module in enumerate(self.layeredModules):
                 x, last_state = process_module(module, x, last_state)
                 if preprocess_only and idx == self.last_kv_layer_idx:
                     x = None
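Net effect of the series: `self.modules` is left intact, the repeated sequence lives in `self.layeredModules`, and `_forward` walks whichever list applies. One thing worth verifying before building on this final state: `head_layer_idx` and `last_kv_layer_idx` are still computed from `len(self.modules)` (those two lines are unchanged context in PATCH 6/6), while the forward loop now enumerates the longer `layeredModules`, so the head-narrowing check in `process_module` may no longer line up once repeats lengthen the sequence. Continuing the toy example from above (strings for modules, illustrative intervals):

    modules = ["embed"] + [m for i in range(8) for m in (f"attn.{i}", f"mlp.{i}")] + ["norm", "head"]
    repeats = [(0, 4), (2, 8)]

    layer_arrangement = [i for a, b in repeats for i in range(a, b)]
    layered = modules[:1]                                  # embedTokenLayers = 1
    for idx in layer_arrangement:
        layered += modules[2 * idx + 1 : 2 * idx + 3]      # transformerSublayers = 2
    layered += modules[-2:]

    print(len(modules) - 1)          # 18 <- head_layer_idx as computed from self.modules
    print(layered.index("head"))     # 22 <- where the head actually sits in layeredModules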