diff --git a/exllamav2/attn.py b/exllamav2/attn.py index 4370dceb..26f127cb 100644 --- a/exllamav2/attn.py +++ b/exllamav2/attn.py @@ -840,7 +840,7 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para # SDPA - if has_lower_right_sdpa and attn_params.is_causal() and not cfg.no_sdpa and not cfg.attn_logit_softcapping: + if has_lower_right_sdpa and not cfg.no_sdpa and not cfg.attn_logit_softcapping: k_states = self.repeat_kv(k_states, cfg.num_key_value_groups) v_states = self.repeat_kv(v_states, cfg.num_key_value_groups) @@ -849,7 +849,10 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para k_states = k_states[:, :, -self.sliding_window:, :] v_states = v_states[:, :, -self.sliding_window:, :] - attn_mask_lr = causal_lower_right(q_len, k_states.shape[2]) + if attn_params.is_causal(): + attn_mask_lr = causal_lower_right(q_len, k_states.shape[2]) + else: + attn_mask_lr = attn_params.get_attn_mask(q_states.device) attn_output = F.scaled_dot_product_attention( q_states, k_states, diff --git a/exllamav2/generator/sampler.py b/exllamav2/generator/sampler.py index 555db71a..9f795c8f 100644 --- a/exllamav2/generator/sampler.py +++ b/exllamav2/generator/sampler.py @@ -400,10 +400,17 @@ def prep_logit_filter(lf): pass_tokens = None end_tokens = None - for f in filters: + pts = [] + ets = [] + for f in filters: pt, et = f.get_next() - if len(filters) > 1 and not isinstance(pt, set): + if pt is not None: + pts.append(pt) + ets.append(et) + + for pt, et in zip(pts, ets): + if len(pts) > 1 and not isinstance(pt, set): pt, et = set(pt), set(et) if pt is not None: pass_tokens = pt if pass_tokens is None else pass_tokens & pt @@ -425,7 +432,7 @@ def prep_logit_filter(lf): if filter_prefer_eos and tokenizer.eos_token_id in pass_tokens: pass_tokens_list = [tokenizer.eos_token_id] logit_filter = prep_logit_filter(logit_filter) - ext_c.logit_filter_exclusive(logit_filter, pass_tokens_list) + ext_c.logit_filter_exclusive(logit_filter, [pass_tokens_list]) else: logit_filter = prep_logit_filter(logit_filter) if isinstance(pass_tokens, set): diff --git a/exllamav2/tokenizer/tokenizer.py b/exllamav2/tokenizer/tokenizer.py index 338a90c7..9e8d8d09 100644 --- a/exllamav2/tokenizer/tokenizer.py +++ b/exllamav2/tokenizer/tokenizer.py @@ -8,6 +8,16 @@ ExLlamaV2TokenizerSPM, ExLlamaV2TokenizerHF ) +import threading + + +lock = threading.RLock() +def synchronized_init(func): + def wrapper(*args, **kwargs): + with lock: + return func(*args, **kwargs) + return wrapper + class ExLlamaV2Tokenizer: @@ -20,7 +30,6 @@ def __init__(self, children = None, leaf = None): self.children = children if children is not None else {} self.leaf = leaf if leaf is not None else [] - config: ExLlamaV2Config tokenizer_model: ExLlamaV2TokenizerBase @@ -567,8 +576,8 @@ def num_tokens(self, text): # Get ordinals of single-byte tokens + @synchronized_init def get_id_to_ord_list(self): - if self.id_to_ord is not None: return self.id_to_ord self.id_to_ord = [] @@ -594,6 +603,7 @@ def get_id_to_ord_list(self): # Copy vocabulary from model + @synchronized_init def get_id_to_piece_list(self, include_special_tokens = False): if include_special_tokens: @@ -633,6 +643,7 @@ def get_id_to_piece_list(self, include_special_tokens = False): return self.id_to_piece + @synchronized_init def get_piece_to_id_dict(self): if self.piece_to_id is not None: return self.piece_to_id @@ -644,6 +655,7 @@ def get_piece_to_id_dict(self): # Create dictionary mapping prefixes to token IDs + @synchronized_init def get_prefix_to_ids_dict(self): if self.prefix_to_ids is not None: return self.prefix_to_ids @@ -671,6 +683,7 @@ def get_prefix_to_ids_dict(self): # Create dictionary mapping each ID to any IDs that it prefixes + @synchronized_init def get_prefix_id_to_ids_dict(self): if self.prefix_id_to_ids is not None: return self.prefix_id_to_ids @@ -712,6 +725,7 @@ def _make_trie(self, ci): return trie + @synchronized_init def get_char_trie(self): if self.char_trie is not None: return self.char_trie @@ -720,6 +734,7 @@ def get_char_trie(self): return self.char_trie + @synchronized_init def get_char_trie_ci(self): if self.char_trie_ci is not None: return self.char_trie_ci diff --git a/exllamav2/version.py b/exllamav2/version.py index 77648b6b..984fc572 100644 --- a/exllamav2/version.py +++ b/exllamav2/version.py @@ -1 +1 @@ -__version__ = "0.2.1" \ No newline at end of file +__version__ = "0.2.2" \ No newline at end of file