change float16 to float for workaround
loubbrad committed Jun 4, 2024
1 parent dca899d commit 1d69841
Showing 2 changed files with 5 additions and 3 deletions.
6 changes: 4 additions & 2 deletions amt/inference/transcribe.py
@@ -135,7 +135,8 @@ def wrapper(*args, **kwargs):
             with torch.autocast("cuda", dtype=torch.bfloat16):
                 return func(*args, **kwargs)
         else:
-            with torch.autocast("cuda", dtype=torch.float16):
+            # TODO: We are using float instead of float16 due to strange bug
+            with torch.autocast("cuda", dtype=torch.float):
                 return func(*args, **kwargs)

     return wrapper
@@ -220,6 +221,7 @@ def process_segments(
                 [], device=seq.device, dtype=torch.int
             ),
         )
+        assert not torch.isnan(logits).any(), "NaN seen in logits"

         logits[:, 389] *= 1.05
         next_tok_ids = torch.argmax(logits, dim=-1)
@@ -274,7 +276,7 @@ def gpu_manager(
     model.decoder.setup_cache(
         batch_size=batch_size,
         max_seq_len=MAX_BLOCK_LEN,
-        dtype=torch.bfloat16 if is_bf16_supported() else torch.float16,
+        dtype=torch.bfloat16 if is_bf16_supported() else torch.float,
     )
     model.cuda()
     model.eval()
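For context, here is a minimal sketch of the precision-selection pattern the first hunk implements. The decorator and function names below are hypothetical (not the repository's); it mirrors the idea of preferring bfloat16 autocast when the GPU supports it and otherwise falling back to torch.float (float32), with the same NaN guard added in process_segments.

from functools import wraps

import torch
from torch.cuda import is_bf16_supported


def autocast_with_fallback(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        if is_bf16_supported():
            # bfloat16 keeps the float32 exponent range, so overflow/NaN
            # issues are far less likely than with float16.
            with torch.autocast("cuda", dtype=torch.bfloat16):
                return func(*args, **kwargs)
        else:
            # torch.float is float32: running in full precision sidesteps
            # the float16 problem noted in the TODO, at some cost in speed
            # and memory.
            with torch.autocast("cuda", dtype=torch.float):
                return func(*args, **kwargs)

    return wrapper


@autocast_with_fallback
def forward_pass(model, x):
    logits = model(x)
    # Same style of guard as the one added to process_segments: fail fast
    # if the numerics blow up instead of silently decoding garbage.
    assert not torch.isnan(logits).any(), "NaN seen in logits"
    return logits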
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,5 +1,5 @@
 aria @ git+https://github.com/EleutherAI/aria.git
-torch == 2.2
+torch >= 2.2
 torchaudio
 accelerate
 psutil
