diff --git a/src/brevitas_examples/llm/llm_quant/eval.py b/src/brevitas_examples/llm/llm_quant/eval.py index b60d544c2..271a5b36e 100644 --- a/src/brevitas_examples/llm/llm_quant/eval.py +++ b/src/brevitas_examples/llm/llm_quant/eval.py @@ -33,15 +33,14 @@ def create_validation_dataloader(data, seqlen, device): @torch.no_grad() def model_eval(model, valenc, seqlen): - nsamples = len(valenc) - + dev = next(iter(model.parameters())).device with torch.no_grad(): nlls = [] for inps in valenc: lm_logits = model(**inps)['logits'] shift_logits = lm_logits[:, :-1, :].contiguous() - shift_labels = inps['input_ids'][:, 1:].to(model.device) + shift_labels = inps['input_ids'][:, 1:].to(dev) loss_fct = nn.CrossEntropyLoss() loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) neg_log_likelihood = loss.float() * seqlen diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index b033f0e26..5237c31c7 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -278,7 +278,8 @@ def main(): nsamples=args.nsamples, tokenizer=tokenizer, seqlen=args.seqlen, seed=0) val_data = get_wikitext2( nsamples=args.nsamples, tokenizer=tokenizer, seqlen=args.seqlen, split='validation', seed=0) - val_data = create_validation_dataloader(val_data, args.seqlen, model.device) + device = next(iter(model.parameters())).device + val_data = create_validation_dataloader(val_data, args.seqlen, device) print("Data loaded.") # Apply LN affine merging before inserting MHA layers