Skip to content

Commit

Permalink
Bug fix for multiple different vocabularies
Browse files Browse the repository at this point in the history
  • Loading branch information
Dan-wanna-M committed Sep 10, 2024
1 parent 3bc0718 commit 210ebe2
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ python-source = "python"
features = ["python"]
[project]
name = "kbnf"
version = "0.3.8"
version = "0.3.9"
dependencies = ["numpy"]
requires-python = ">=3.7"
classifiers = [
Expand Down
12 changes: 7 additions & 5 deletions python/kbnf/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,23 +40,21 @@ def convert_slice(tensor:typing.Any)->typing.Optional[typing.Tuple[typing.Any,in
return convert_slice

def _torch_fast_mask_logits(module:types.ModuleType):
cache: typing.Dict[bytes, module.Tensor] = {}
cache = {}
def mask_logits_fast(tensor:typing.Any, engine:"Engine")->typing.Optional[typing.Any]:
if isinstance(tensor, module.Tensor):
assert tensor.dim() == 1, f"Only 1D tensor is supported, while the actual tensor shape is {tensor.shape}"
index = engine.get_index_of_allowed_token_ids()
if index not in cache:
if index not in engine._cache:
length = engine.get_number_of_disallowed_token_ids()
if length == 0: # Rust FFI requires non-null pointer
return tensor
indices = module.empty((length,), device="cpu",dtype=module.int64)
data_ptr = indices.data_ptr()
assert data_ptr % 8 == 0, f"The indices data pointer which points to {data_ptr} is not aligned to 8 bytes"
engine.write_disallowed_token_ids_to_buffer(data_ptr, length)
cache[index] = indices
engine._cache[index] = indices
else:
indices = cache[index]
indices = engine._cache[index]
tensor.index_fill_(0,indices.to(device=tensor.device),-float("inf"))
return tensor
return None
Expand Down Expand Up @@ -90,6 +88,10 @@ def _mask_logits_fast(logits:typing.Any,engine:"Engine")->typing.Optional[typing
return None

class Engine(InternalEngine):
def __init__(self, kbnf_syntax_grammar_str, vocabulary,config=None): # signature is only needed for python runtime type checking
super().__init__() # pyo3 works by making magics on __new__ and the __init__ is just a placeholder
self._cache = {}

def mask_logits(self, logits):
"""
Masks the logits based on last computed token IDs.
Expand Down

0 comments on commit 210ebe2

Please sign in to comment.