Skip to content

Commit

Permalink
Further optimize mask_logits
Browse files Browse the repository at this point in the history
  • Loading branch information
Dan-wanna-M committed Sep 12, 2024
1 parent d30dff6 commit bd3f363
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions python/kbnf/engine.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
import types
import typing
import importlib
Expand Down Expand Up @@ -40,7 +41,7 @@ def convert_slice(tensor:typing.Any)->typing.Optional[typing.Tuple[typing.Any,in
return convert_slice

def _torch_fast_mask_logits(module:types.ModuleType):
ninf = -float("inf")
ninf = -math.inf
def mask_logits_fast(tensor:typing.Any, engine:"Engine")->typing.Optional[typing.Any]:
if isinstance(tensor, module.Tensor):
assert tensor.dim() == 1, f"Only 1D tensor is supported, while the actual tensor shape is {tensor.shape}"
Expand All @@ -51,11 +52,13 @@ def mask_logits_fast(tensor:typing.Any, engine:"Engine")->typing.Optional[typing
if length == 0: # Rust FFI requires non-null pointer
return tensor
disallowed = module.empty((length,), device="cpu",dtype=module.int64)
disallowed.pin_memory()
data_ptr = disallowed.data_ptr()
assert data_ptr % 8 == 0, f"The indices data pointer which points to {data_ptr} is not aligned to 8 bytes"
engine.write_disallowed_token_ids_to_buffer(data_ptr, length)
length = engine.get_number_of_allowed_token_ids()
allowed = module.empty((length,), device="cpu",dtype=module.int64)
allowed.pin_memory()
data_ptr = allowed.data_ptr()
assert data_ptr % 8 == 0, f"The allowed data pointer which points to {data_ptr} is not aligned to 8 bytes"
engine.write_allowed_token_ids_to_buffer(data_ptr, length)
Expand All @@ -64,11 +67,11 @@ def mask_logits_fast(tensor:typing.Any, engine:"Engine")->typing.Optional[typing
disallowed, allowed = engine._cache[index]
if num_of_disallowed>tensor.shape[-1]/2: # we have more disallowed than allowed
new_tensor = module.full_like(tensor,fill_value=ninf)
allowed = allowed.to(device=tensor.device)
allowed = allowed.to(device=tensor.device,non_blocking=True)
new_tensor.put_(allowed, tensor.take(allowed))
return new_tensor
else: # we have more allowed than disallowed
tensor.index_fill_(0,disallowed.to(device=tensor.device),ninf)
tensor.index_fill_(0,disallowed.to(device=tensor.device,non_blocking=True),ninf)
return tensor
return None
return mask_logits_fast
Expand Down

0 comments on commit bd3f363

Please sign in to comment.