removed batch_name to fix a KeyError with "uttid" (#1172)
JinZr authored Jul 15, 2023
1 parent 5ed6fc0 commit 4ab7d61
Showing 1 changed file with 2 additions and 9 deletions.
11 changes: 2 additions & 9 deletions egs/librispeech/ASR/conformer_ctc2/train.py
@@ -675,7 +675,6 @@ def train_one_epoch(
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
-        batch_name = batch["supervisions"]["uttid"]

         with torch.cuda.amp.autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
@@ -698,10 +697,7 @@ def train_one_epoch(
             scaler.scale(loss).backward()
         except RuntimeError as e:
             if "CUDA out of memory" in str(e):
-                logging.error(
-                    f"failing batch size:{batch_size} "
-                    f"failing batch names {batch_name}"
-                )
+                logging.error(f"failing batch size:{batch_size} ")
             raise

     scheduler.step_batch(params.batch_idx_train)
@@ -756,10 +752,7 @@ def train_one_epoch(
             if loss_info["ctc_loss"] == float("inf") or loss_info["att_loss"] == float(
                 "inf"
             ):
-                logging.error(
-                    "Your loss contains inf, something goes wrong"
-                    f"failing batch names {batch_name}"
-                )
+                logging.error("Your loss contains inf, something goes wrong")
             if tb_writer is not None:
                 tb_writer.add_scalar(
                     "train/learning_rate", cur_lr, params.batch_idx_train
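For context: the supervisions dict collated by the lhotse dataset used in this recipe carries no "uttid" key, so indexing batch["supervisions"]["uttid"] raised the KeyError this commit removes. If per-utterance IDs are still wanted in these error logs, the sketch below shows one defensive way to recover them. It is a minimal sketch, assuming the dataset was built with return_cuts=True so the originating cuts are available under supervisions["cut"]; the helper name batch_utt_ids is hypothetical, not part of the recipe.

from typing import Any, Dict, List


def batch_utt_ids(batch: Dict[str, Any]) -> List[str]:
    """Best-effort utterance IDs for a lhotse-collated batch.

    Assumes return_cuts=True was passed when building the dataset, so the
    originating cuts sit under batch["supervisions"]["cut"]; returns an
    empty list instead of raising when that key is absent.
    """
    cuts = batch["supervisions"].get("cut", [])
    return [cut.id for cut in cuts]


# Hypothetical use inside the OOM handler from the diff above:
#     logging.error(
#         f"failing batch size:{batch_size} "
#         f"failing batch names {batch_utt_ids(batch)}"
#     )

Using .get() with a default keeps the error-logging path itself from raising a second exception, which is exactly the failure mode the original batch_name lookup triggered.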
