Skip to content

Commit

Permalink
drafting
Browse files Browse the repository at this point in the history
  • Loading branch information
riship committed Jan 24, 2025
1 parent 69438de commit f860da7
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions examples/llm/tech_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,15 +229,13 @@ def eval(question: str, pred: str, correct_answer: str):
return llm_judge.score(question, pred, correct_answer)

scores = []
for test_batch in test_loader:
preds = inference_step(model, test_batch)
"""
tqdm is used here because there is only one mini-batch for testing:
there are only 11 test questions and eval-batch-size
is 16 by default
"""
for question, pred, label in tqdm(list(zip(test_batch.question, preds, test_batch.label)), desc="Test:"):
scores.append(eval(question, pred, label))
# Two passes: first run inference over every test batch and collect
# (question, prediction, label) triples, then score each triple.
eval_tuples = []
for test_batch in tqdm(test_loader, desc="Testing"):
    # The batch predictions must be bound to `preds`, the name the zip
    # below reads; binding them to `pred` left `preds` undefined.
    preds = inference_step(model, test_batch)
    eval_tuples.extend(zip(test_batch.question, preds, test_batch.label))
for question, pred, label in tqdm(eval_tuples, desc="Eval"):
    scores.append(eval(question, pred, label))
avg_scores = sum(scores) / len(scores)
print("Avg marlin accuracy=", avg_scores)

Expand Down

0 comments on commit f860da7

Please sign in to comment.