diff --git a/python/kbnf/engine.py b/python/kbnf/engine.py index 8b436d2..e434cf0 100644 --- a/python/kbnf/engine.py +++ b/python/kbnf/engine.py @@ -1,3 +1,4 @@ +import math import types import typing import importlib @@ -40,7 +41,7 @@ def convert_slice(tensor:typing.Any)->typing.Optional[typing.Tuple[typing.Any,in return convert_slice def _torch_fast_mask_logits(module:types.ModuleType): - ninf = -float("inf") + ninf = -math.inf def mask_logits_fast(tensor:typing.Any, engine:"Engine")->typing.Optional[typing.Any]: if isinstance(tensor, module.Tensor): assert tensor.dim() == 1, f"Only 1D tensor is supported, while the actual tensor shape is {tensor.shape}" @@ -51,11 +52,13 @@ def mask_logits_fast(tensor:typing.Any, engine:"Engine")->typing.Optional[typing if length == 0: # Rust FFI requires non-null pointer return tensor disallowed = module.empty((length,), device="cpu",dtype=module.int64) + disallowed.pin_memory() data_ptr = disallowed.data_ptr() assert data_ptr % 8 == 0, f"The indices data pointer which points to {data_ptr} is not aligned to 8 bytes" engine.write_disallowed_token_ids_to_buffer(data_ptr, length) length = engine.get_number_of_allowed_token_ids() allowed = module.empty((length,), device="cpu",dtype=module.int64) + allowed.pin_memory() data_ptr = allowed.data_ptr() assert data_ptr % 8 == 0, f"The allowed data pointer which points to {data_ptr} is not aligned to 8 bytes" engine.write_allowed_token_ids_to_buffer(data_ptr, length) @@ -64,11 +67,11 @@ def mask_logits_fast(tensor:typing.Any, engine:"Engine")->typing.Optional[typing disallowed, allowed = engine._cache[index] if num_of_disallowed>tensor.shape[-1]/2: # we have more disallowed than allowed new_tensor = module.full_like(tensor,fill_value=ninf) - allowed = allowed.to(device=tensor.device) + allowed = allowed.to(device=tensor.device,non_blocking=True) new_tensor.put_(allowed, tensor.take(allowed)) return new_tensor else: # we have more allowed than disallowed - tensor.index_fill_(0,disallowed.to(device=tensor.device),ninf) + tensor.index_fill_(0,disallowed.to(device=tensor.device,non_blocking=True),ninf) return tensor return None return mask_logits_fast