Skip to content

Commit

Permalink
Fix (brevitas_examples/llm): use device for eval (#949)
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianandresgrob authored May 31, 2024
1 parent fc4162e commit 8c71e08
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
9 changes: 4 additions & 5 deletions src/brevitas_examples/llm/llm_quant/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,26 @@
from tqdm import tqdm


- def create_validation_dataloader(data, seqlen):
+ def create_validation_dataloader(data, seqlen, device):
nsamples = data['input_ids'].numel() // seqlen
val_dataloader = []
for i in tqdm(range(nsamples)):
- batch = data['input_ids'][:, (i * seqlen):((i + 1) * seqlen)].cuda()
+ batch = data['input_ids'][:, (i * seqlen):((i + 1) * seqlen)].to(device)
attention_mask = torch.ones_like(batch)
val_dataloader.append({'input_ids': batch, 'attention_mask': attention_mask})
return val_dataloader


@torch.no_grad()
def model_eval(model, valenc, seqlen):

nsamples = len(valenc)

+ dev = next(iter(model.parameters())).device
with torch.no_grad():
nlls = []
for inps in valenc:
lm_logits = model(**inps)['logits']
shift_logits = lm_logits[:, :-1, :].contiguous()
- shift_labels = inps['input_ids'][:, 1:].cuda()
+ shift_labels = inps['input_ids'][:, 1:].to(dev)
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
neg_log_likelihood = loss.float() * seqlen
Expand Down
3 changes: 2 additions & 1 deletion src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,8 @@ def main():
nsamples=args.nsamples, tokenizer=tokenizer, seqlen=args.seqlen, seed=0)
val_data = get_wikitext2(
nsamples=args.nsamples, tokenizer=tokenizer, seqlen=args.seqlen, split='validation', seed=0)
- val_data = create_validation_dataloader(val_data, args.seqlen)
+ device = next(iter(model.parameters())).device
+ val_data = create_validation_dataloader(val_data, args.seqlen, device)
print("Data loaded.")

# Apply LN affine merging before inserting MHA layers
Expand Down

0 comments on commit 8c71e08

Please sign in to comment.