Skip to content

Commit

Permalink
fix device issue
Browse files (browse the repository at this point in the history)
  • Loading branch information
Sara Adkins committed May 3, 2024
1 parent bfc1136 commit 650c236
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/compressed_tensors/quantization/observers/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,13 @@ def calculate_qparams(
"""
 min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
 max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
+device = min_vals.device

 bit_range = 2**quantization_args.num_bits - 1
 bit_min = -(bit_range + 1) / 2
 bit_max = bit_min + bit_range
 if quantization_args.symmetric:
-zero_points = torch.tensor(0).to(torch.int8)
+zero_points = torch.tensor(0, device=device).to(torch.int8)
 max_val_pos = torch.max(-min_vals, max_vals)
 scales = max_val_pos / (float(bit_range) / 2)
 scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
Expand Down

0 comments on commit 650c236

Please sign in to comment.