diff --git a/pyproject.toml b/pyproject.toml
index c4c70af..1f69004 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,7 +6,7 @@ python-source = "python"
 features = ["python"]
 [project]
 name = "kbnf"
-version = "0.3.8"
+version = "0.3.9"
 dependencies = ["numpy"]
 requires-python = ">=3.7"
 classifiers = [
diff --git a/python/kbnf/engine.py b/python/kbnf/engine.py
index 4878015..69c78db 100644
--- a/python/kbnf/engine.py
+++ b/python/kbnf/engine.py
@@ -40,13 +40,11 @@ def convert_slice(tensor:typing.Any)->typing.Optional[typing.Tuple[typing.Any,in
     return convert_slice
 
 def _torch_fast_mask_logits(module:types.ModuleType):
-    cache: typing.Dict[bytes, module.Tensor] = {}
-    cache = {}
     def mask_logits_fast(tensor:typing.Any, engine:"Engine")->typing.Optional[typing.Any]:
         if isinstance(tensor, module.Tensor):
             assert tensor.dim() == 1, f"Only 1D tensor is supported, while the actual tensor shape is {tensor.shape}"
             index = engine.get_index_of_allowed_token_ids()
-            if index not in cache:
+            if index not in engine._cache:
                 length = engine.get_number_of_disallowed_token_ids()
                 if length == 0: # Rust FFI requires non-null pointer
                     return tensor
@@ -54,9 +52,9 @@
                 data_ptr = indices.data_ptr()
                 assert data_ptr % 8 == 0, f"The indices data pointer which points to {data_ptr} is not aligned to 8 bytes"
                 engine.write_disallowed_token_ids_to_buffer(data_ptr, length)
-                cache[index] = indices
+                engine._cache[index] = indices
             else:
-                indices = cache[index]
+                indices = engine._cache[index]
             tensor.index_fill_(0,indices.to(device=tensor.device),-float("inf"))
             return tensor
         return None
@@ -90,6 +88,10 @@ def _mask_logits_fast(logits:typing.Any,engine:"Engine")->typing.Optional[typing
     return None
 
 class Engine(InternalEngine):
+    def __init__(self, kbnf_syntax_grammar_str, vocabulary,config=None): # signature is only needed for python runtime type checking
+        super().__init__() # pyo3 works by making magics on __new__ and the __init__ is just a placeholder
+        self._cache = {}
+
     def mask_logits(self, logits):
         """
         Masks the logits based on last computed token IDs.
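
For readers skimming the diff: the substantive change in `engine.py` is that the fast-path cache of disallowed-token index tensors moves from a closure-level dict (which was also redundantly assigned twice) onto each `Engine` instance as `self._cache`, so cached tensors are scoped to, and freed with, the engine that produced them rather than accumulating in one dict shared by every engine created through `_torch_fast_mask_logits`. Below is a minimal sketch of the resulting behavior, not part of the diff: the grammar strings and the `build_vocabulary` helper are hypothetical placeholders, while `Engine(...)` and `mask_logits(...)` are taken from the code above.

```python
import torch
from kbnf import Engine

# Hypothetical helper: build a kbnf vocabulary from your tokenizer.
# See the kbnf docs for the actual Vocabulary constructor.
vocabulary = build_vocabulary()

# Illustrative KBNF grammar strings; any two distinct grammars make the point.
engine_a = Engine('start ::= "yes" | "no";', vocabulary)
engine_b = Engine('start ::= "0" | "1";', vocabulary)

logits = torch.randn(32000)  # 1D logits, as the fast path asserts; size is illustrative

# In real use each engine would first update its allowed-token state from the
# generated prefix. Each call below now reads and writes its own engine._cache,
# keyed by that engine's get_index_of_allowed_token_ids(), so engine_a's cached
# disallowed-token indices are never served to engine_b, and both caches are
# garbage-collected along with their engines.
masked_a = engine_a.mask_logits(logits.clone())
masked_b = engine_b.mask_logits(logits.clone())
```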