Commit 46eff43

Merge branch 'refs/heads/dev'

turboderp committed Sep 14, 2024
2 parents 1e18e80 + 228ba34
Showing 4 changed files with 33 additions and 8 deletions.
7 changes: 5 additions & 2 deletions exllamav2/attn.py

@@ -840,7 +840,7 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para

         # SDPA

-        if has_lower_right_sdpa and attn_params.is_causal() and not cfg.no_sdpa and not cfg.attn_logit_softcapping:
+        if has_lower_right_sdpa and not cfg.no_sdpa and not cfg.attn_logit_softcapping:

             k_states = self.repeat_kv(k_states, cfg.num_key_value_groups)
             v_states = self.repeat_kv(v_states, cfg.num_key_value_groups)
@@ -849,7 +849,10 @@ def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_para
                 k_states = k_states[:, :, -self.sliding_window:, :]
                 v_states = v_states[:, :, -self.sliding_window:, :]

-            attn_mask_lr = causal_lower_right(q_len, k_states.shape[2])
+            if attn_params.is_causal():
+                attn_mask_lr = causal_lower_right(q_len, k_states.shape[2])
+            else:
+                attn_mask_lr = attn_params.get_attn_mask(q_states.device)
             attn_output = F.scaled_dot_product_attention(
                 q_states,
                 k_states,
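The first hunk drops the is_causal() requirement from the SDPA gate; the second picks the mask accordingly: the lower-right causal bias when the parameters are causal, otherwise whatever mask attn_params supplies. A minimal sketch of that selection outside the library, assuming PyTorch >= 2.2 for causal_lower_right; is_causal and explicit_mask stand in for attn_params here:

    import torch
    import torch.nn.functional as F
    from torch.nn.attention.bias import causal_lower_right

    q = torch.randn(1, 8, 4, 64)     # (batch, heads, q_len, head_dim)
    k = torch.randn(1, 8, 10, 64)    # (batch, heads, kv_len, head_dim)
    v = torch.randn(1, 8, 10, 64)

    is_causal = False
    # Stand-in for attn_params.get_attn_mask(): an additive float mask,
    # zeros = fully visible.
    explicit_mask = torch.zeros(1, 1, 4, 10)

    if is_causal:
        # Lower-right alignment: the last query row attends to the last key,
        # which is the correct shape when q_len < kv_len (cached generation).
        mask = causal_lower_right(4, 10)
    else:
        mask = explicit_mask

    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    print(out.shape)  # torch.Size([1, 8, 4, 64])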
13 changes: 10 additions & 3 deletions exllamav2/generator/sampler.py

@@ -400,10 +400,17 @@ def prep_logit_filter(lf):

         pass_tokens = None
         end_tokens = None
-        for f in filters:
+
+        pts = []
+        ets = []
+        for f in filters:
             pt, et = f.get_next()
-            if len(filters) > 1 and not isinstance(pt, set):
+            if pt is not None:
+                pts.append(pt)
+                ets.append(et)
+
+        for pt, et in zip(pts, ets):
+            if len(pts) > 1 and not isinstance(pt, set):
                 pt, et = set(pt), set(et)

             if pt is not None: pass_tokens = pt if pass_tokens is None else pass_tokens & pt
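This hunk splits collection from combination: filters whose get_next() returns pt = None (no constraint) are dropped before the set conversion, so the conversion can no longer be reached with None, and the length test counts only the filters that actually constrain sampling. A stand-in sketch of the aggregation; the filter class and token IDs are illustrative, and end tokens are collected here but combined further down in the real code:

    class KeywordFilter:
        # Hypothetical filter: only the given token IDs may be sampled next.
        def __init__(self, allowed):
            self.allowed = allowed
        def get_next(self):
            return self.allowed, []   # (pass_tokens, end_tokens)

    def aggregate_pass_tokens(filters):
        pts, ets = [], []
        for f in filters:
            pt, et = f.get_next()
            if pt is not None:        # skip filters that impose no constraint
                pts.append(pt)
                ets.append(et)

        pass_tokens = None
        for pt in pts:
            if len(pts) > 1 and not isinstance(pt, set):
                pt = set(pt)
            pass_tokens = pt if pass_tokens is None else pass_tokens & pt
        return pass_tokens

    print(aggregate_pass_tokens([KeywordFilter([1, 2, 3]),
                                 KeywordFilter([2, 3, 4])]))  # {2, 3}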
@@ -425,7 +432,7 @@ def prep_logit_filter(lf):
         if filter_prefer_eos and tokenizer.eos_token_id in pass_tokens:
             pass_tokens_list = [tokenizer.eos_token_id]
             logit_filter = prep_logit_filter(logit_filter)
-            ext_c.logit_filter_exclusive(logit_filter, pass_tokens_list)
+            ext_c.logit_filter_exclusive(logit_filter, [pass_tokens_list])
         else:
             logit_filter = prep_logit_filter(logit_filter)
             if isinstance(pass_tokens, set):
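The second hunk wraps the single pass_tokens_list in an outer list, which suggests ext_c.logit_filter_exclusive expects one token list per batch row. A sketch of what an exclusive filter does; the real call goes through the C++ extension, and the tensor layout and bool dtype here are assumptions:

    import torch

    def logit_filter_exclusive(logit_filter: torch.Tensor, pass_tokens: list):
        # logit_filter: (batch, vocab) bool tensor, True = token currently allowed
        for row, tokens in enumerate(pass_tokens):
            keep = torch.zeros_like(logit_filter[row])
            keep[torch.tensor(tokens)] = True
            logit_filter[row] &= keep   # allow only the listed tokens in this row

    lf = torch.ones(1, 10, dtype=torch.bool)
    logit_filter_exclusive(lf, [[3]])   # note the outer list: one entry per row
    print(lf.nonzero())                 # tensor([[0, 3]])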
19 changes: 17 additions & 2 deletions exllamav2/tokenizer/tokenizer.py

@@ -8,6 +8,16 @@
     ExLlamaV2TokenizerSPM,
     ExLlamaV2TokenizerHF
 )
+import threading
+
+
+lock = threading.RLock()
+def synchronized_init(func):
+    def wrapper(*args, **kwargs):
+        with lock:
+            return func(*args, **kwargs)
+    return wrapper
+

 class ExLlamaV2Tokenizer:

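synchronized_init is a plain decorator: every wrapped call takes a module-wide re-entrant lock, serializing the lazy cache builders decorated below. Without it, two threads sharing a tokenizer can both see an empty cache and build it twice, or read it half-built. A standalone sketch of the pattern, with an illustrative stand-in class:

    import threading

    lock = threading.RLock()

    def synchronized_init(func):
        def wrapper(*args, **kwargs):
            with lock:
                return func(*args, **kwargs)
        return wrapper

    class VocabCache:
        # Illustrative stand-in for the tokenizer's lazy caches.
        def __init__(self):
            self.pieces = None

        @synchronized_init
        def get_pieces(self):
            # Without the lock, two threads can both observe None here and
            # build the list twice, or one can read it half-built.
            if self.pieces is not None: return self.pieces
            self.pieces = [f"tok_{i}" for i in range(10)]
            return self.pieces

    cache = VocabCache()
    threads = [threading.Thread(target=cache.get_pieces) for _ in range(4)]
    for t in threads: t.start()
    for t in threads: t.join()
    print(len(cache.get_pieces()))  # 10, built exactly once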
@@ -20,7 +30,6 @@ def __init__(self, children = None, leaf = None):
             self.children = children if children is not None else {}
             self.leaf = leaf if leaf is not None else []

-
     config: ExLlamaV2Config

     tokenizer_model: ExLlamaV2TokenizerBase
@@ -567,8 +576,8 @@ def num_tokens(self, text):

     # Get ordinals of single-byte tokens

+    @synchronized_init
     def get_id_to_ord_list(self):
-
         if self.id_to_ord is not None: return self.id_to_ord

         self.id_to_ord = []
@@ -594,6 +603,7 @@ def get_id_to_ord_list(self):

     # Copy vocabulary from model

+    @synchronized_init
     def get_id_to_piece_list(self, include_special_tokens = False):

         if include_special_tokens:
@@ -633,6 +643,7 @@ def get_id_to_piece_list(self, include_special_tokens = False):
         return self.id_to_piece


+    @synchronized_init
     def get_piece_to_id_dict(self):

         if self.piece_to_id is not None: return self.piece_to_id
@@ -644,6 +655,7 @@ def get_piece_to_id_dict(self):

     # Create dictionary mapping prefixes to token IDs

+    @synchronized_init
     def get_prefix_to_ids_dict(self):

         if self.prefix_to_ids is not None: return self.prefix_to_ids
@@ -671,6 +683,7 @@ def get_prefix_to_ids_dict(self):

     # Create dictionary mapping each ID to any IDs that it prefixes

+    @synchronized_init
     def get_prefix_id_to_ids_dict(self):

         if self.prefix_id_to_ids is not None: return self.prefix_id_to_ids
@@ -712,6 +725,7 @@ def _make_trie(self, ci):
         return trie


+    @synchronized_init
     def get_char_trie(self):

         if self.char_trie is not None: return self.char_trie
@@ -720,6 +734,7 @@ def get_char_trie(self):
         return self.char_trie


+    @synchronized_init
     def get_char_trie_ci(self):

         if self.char_trie_ci is not None: return self.char_trie_ci
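An RLock rather than a plain Lock matters because the same thread may re-enter: these lazy builders can call one another (a prefix dictionary reusing the piece list, for instance), and a non-reentrant lock would deadlock on itself. A small sketch with a hypothetical Caches class:

    import threading

    lock = threading.RLock()

    def synchronized_init(func):
        def wrapper(*args, **kwargs):
            with lock:
                return func(*args, **kwargs)
        return wrapper

    class Caches:
        # Hypothetical pair of builders where one calls the other.
        @synchronized_init
        def inner(self):
            return {"a": 1}

        @synchronized_init
        def outer(self):
            # Re-acquires the lock this thread already holds: fine with an
            # RLock, a deadlock with a plain threading.Lock.
            return list(self.inner())

    print(Caches().outer())  # ['a']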
2 changes: 1 addition & 1 deletion exllamav2/version.py

@@ -1 +1 @@
-__version__ = "0.2.1"
+__version__ = "0.2.2"
