From 1df7b04821be180d33b4c69c64e1b31c8173cde4 Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Fri, 13 Sep 2024 02:20:24 +0200 Subject: [PATCH 1/5] Allow non-causal attn with SDPA --- exllamav2/attn.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/exllamav2/attn.py b/exllamav2/attn.py index 4370dceb..26f127cb 100644 --- a/exllamav2/attn.py +++ b/exllamav2/attn.py @@ -840,7 +840,7 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para # SDPA - if has_lower_right_sdpa and attn_params.is_causal() and not cfg.no_sdpa and not cfg.attn_logit_softcapping: + if has_lower_right_sdpa and not cfg.no_sdpa and not cfg.attn_logit_softcapping: k_states = self.repeat_kv(k_states, cfg.num_key_value_groups) v_states = self.repeat_kv(v_states, cfg.num_key_value_groups) @@ -849,7 +849,10 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para k_states = k_states[:, :, -self.sliding_window:, :] v_states = v_states[:, :, -self.sliding_window:, :] - attn_mask_lr = causal_lower_right(q_len, k_states.shape[2]) + if attn_params.is_causal(): + attn_mask_lr = causal_lower_right(q_len, k_states.shape[2]) + else: + attn_mask_lr = attn_params.get_attn_mask(q_states.device) attn_output = F.scaled_dot_product_attention( q_states, k_states, From 5ee983593babbbd60068fad43dd0bacd27f998d6 Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Sat, 14 Sep 2024 19:28:02 +0200 Subject: [PATCH 2/5] Fix potential race condition with multithreaded sampling and lazy tokenizer initialization --- exllamav2/tokenizer/tokenizer.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/exllamav2/tokenizer/tokenizer.py b/exllamav2/tokenizer/tokenizer.py index 338a90c7..9e8d8d09 100644 --- a/exllamav2/tokenizer/tokenizer.py +++ b/exllamav2/tokenizer/tokenizer.py @@ -8,6 +8,16 @@ ExLlamaV2TokenizerSPM, ExLlamaV2TokenizerHF ) +import threading + + +lock = threading.RLock() +def synchronized_init(func): + def wrapper(*args, **kwargs): + with lock: + return func(*args, **kwargs) + return wrapper + class ExLlamaV2Tokenizer: @@ -20,7 +30,6 @@ def __init__(self, children = None, leaf = None): self.children = children if children is not None else {} self.leaf = leaf if leaf is not None else [] - config: ExLlamaV2Config tokenizer_model: ExLlamaV2TokenizerBase @@ -567,8 +576,8 @@ def num_tokens(self, text): # Get ordinals of single-byte tokens + @synchronized_init def get_id_to_ord_list(self): - if self.id_to_ord is not None: return self.id_to_ord self.id_to_ord = [] @@ -594,6 +603,7 @@ def get_id_to_ord_list(self): # Copy vocabulary from model + @synchronized_init def get_id_to_piece_list(self, include_special_tokens = False): if include_special_tokens: @@ -633,6 +643,7 @@ def get_id_to_piece_list(self, include_special_tokens = False): return self.id_to_piece + @synchronized_init def get_piece_to_id_dict(self): if self.piece_to_id is not None: return self.piece_to_id @@ -644,6 +655,7 @@ def get_piece_to_id_dict(self): # Create dictionary mapping prefixes to token IDs + @synchronized_init def get_prefix_to_ids_dict(self): if self.prefix_to_ids is not None: return self.prefix_to_ids @@ -671,6 +683,7 @@ def get_prefix_to_ids_dict(self): # Create dictionary mapping each ID to any IDs that it prefixes + @synchronized_init def get_prefix_id_to_ids_dict(self): if self.prefix_id_to_ids is not None: return self.prefix_id_to_ids @@ -712,6 +725,7 @@ def _make_trie(self, ci): return trie + @synchronized_init def get_char_trie(self): if self.char_trie is not None: return self.char_trie @@ -720,6 +734,7 @@ def get_char_trie(self): return self.char_trie + @synchronized_init def get_char_trie_ci(self): if self.char_trie_ci is not None: return self.char_trie_ci From aadc454183781b849bf0e40f2fb04561d5ba4bae Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Sat, 14 Sep 2024 19:36:53 +0200 Subject: [PATCH 3/5] Fix sampling using multiple filters --- exllamav2/generator/sampler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/exllamav2/generator/sampler.py b/exllamav2/generator/sampler.py index 555db71a..642a91ad 100644 --- a/exllamav2/generator/sampler.py +++ b/exllamav2/generator/sampler.py @@ -404,7 +404,8 @@ def prep_logit_filter(lf): pt, et = f.get_next() if len(filters) > 1 and not isinstance(pt, set): - pt, et = set(pt), set(et) + if pt is not None: pt = set(pt) + if et is not None: et = set(et) if pt is not None: pass_tokens = pt if pass_tokens is None else pass_tokens & pt if et is not None: end_tokens = et if end_tokens is None else end_tokens | et @@ -425,7 +426,7 @@ def prep_logit_filter(lf): if filter_prefer_eos and tokenizer.eos_token_id in pass_tokens: pass_tokens_list = [tokenizer.eos_token_id] logit_filter = prep_logit_filter(logit_filter) - ext_c.logit_filter_exclusive(logit_filter, pass_tokens_list) + ext_c.logit_filter_exclusive(logit_filter, [pass_tokens_list]) else: logit_filter = prep_logit_filter(logit_filter) if isinstance(pass_tokens, set): From a372fe124194065d012351a43611204e71ccb533 Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Sat, 14 Sep 2024 19:42:33 +0200 Subject: [PATCH 4/5] Skip superfluous set creation when possible, even if multiple filters used --- exllamav2/generator/sampler.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/exllamav2/generator/sampler.py b/exllamav2/generator/sampler.py index 642a91ad..9f795c8f 100644 --- a/exllamav2/generator/sampler.py +++ b/exllamav2/generator/sampler.py @@ -400,12 +400,18 @@ def prep_logit_filter(lf): pass_tokens = None end_tokens = None - for f in filters: + pts = [] + ets = [] + for f in filters: pt, et = f.get_next() - if len(filters) > 1 and not isinstance(pt, set): - if pt is not None: pt = set(pt) - if et is not None: et = set(et) + if pt is not None: + pts.append(pt) + ets.append(et) + + for pt, et in zip(pts, ets): + if len(pts) > 1 and not isinstance(pt, set): + pt, et = set(pt), set(et) if pt is not None: pass_tokens = pt if pass_tokens is None else pass_tokens & pt if et is not None: end_tokens = et if end_tokens is None else end_tokens | et From 228ba34cec1bf60d553e10baff1835851556922f Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Sat, 14 Sep 2024 21:13:22 +0200 Subject: [PATCH 5/5] Bump to 0.2.2 --- exllamav2/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exllamav2/version.py b/exllamav2/version.py index 77648b6b..984fc572 100644 --- a/exllamav2/version.py +++ b/exllamav2/version.py @@ -1 +1 @@ -__version__ = "0.2.1" \ No newline at end of file +__version__ = "0.2.2" \ No newline at end of file