Skip to content

Commit

Permalink
Fix (nn): size attn_bias from attn_mask shape in QuantSDPA
Browse files Browse the repository at this point in the history
  • Loading branch information
nickfraser committed Nov 28, 2024
1 parent ef1b779 commit e3d3c4a
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/brevitas/nn/quant_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,10 @@ def forward(
"""
L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
if attn_mask is None:
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
else:
attn_bias = torch.zeros(size=attn_mask.shape, dtype=query.dtype, device=query.device)
if is_causal:
assert attn_mask is None
temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
Expand All @@ -187,7 +190,7 @@ def forward(
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias = attn_bias + attn_mask
attn_bias += attn_mask
q_scaled = self.q_scaled_quant(query * scale_factor)
k_transpose = self.k_transposed_quant(key.transpose(-2, -1))
attn_weight = q_scaled @ k_transpose
Expand Down

0 comments on commit e3d3c4a

Please sign in to comment.