diff --git a/outlines/generate/generator.py b/outlines/generate/generator.py index 4ee3992a3..7a5522f83 100644 --- a/outlines/generate/generator.py +++ b/outlines/generate/generator.py @@ -284,7 +284,7 @@ def bias_logits( A view of the original logits tensor where some values are masked. """ - biased_logits = torch.empty(logits.shape) + biased_logits = torch.empty(logits.shape, device=logits.device) for i, ids in enumerate(ids_to_mask): mask = torch.full((logits.shape[-1],), -math.inf, device=logits.device) mask[ids] = 0