diff --git a/src/brevitas_examples/llm/llm_quant/eval.py b/src/brevitas_examples/llm/llm_quant/eval.py
index a9b2b4b64..271a5b36e 100644
--- a/src/brevitas_examples/llm/llm_quant/eval.py
+++ b/src/brevitas_examples/llm/llm_quant/eval.py
@@ -21,11 +21,11 @@ from tqdm import tqdm
 
 
-def create_validation_dataloader(data, seqlen):
+def create_validation_dataloader(data, seqlen, device):
     nsamples = data['input_ids'].numel() // seqlen
     val_dataloader = []
     for i in tqdm(range(nsamples)):
-        batch = data['input_ids'][:, (i * seqlen):((i + 1) * seqlen)].cuda()
+        batch = data['input_ids'][:, (i * seqlen):((i + 1) * seqlen)].to(device)
         attention_mask = torch.ones_like(batch)
         val_dataloader.append({'input_ids': batch, 'attention_mask': attention_mask})
     return val_dataloader
@@ -33,15 +33,14 @@ def create_validation_dataloader(data, seqlen):
 
 
 @torch.no_grad()
 def model_eval(model, valenc, seqlen):
-    nsamples = len(valenc)
-
+    dev = next(iter(model.parameters())).device
     with torch.no_grad():
         nlls = []
         for inps in valenc:
             lm_logits = model(**inps)['logits']
             shift_logits = lm_logits[:, :-1, :].contiguous()
-            shift_labels = inps['input_ids'][:, 1:].cuda()
+            shift_labels = inps['input_ids'][:, 1:].to(dev)
             loss_fct = nn.CrossEntropyLoss()
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
             neg_log_likelihood = loss.float() * seqlen
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index e86a8d4ba..5237c31c7 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -278,7 +278,8 @@ def main():
         nsamples=args.nsamples, tokenizer=tokenizer, seqlen=args.seqlen, seed=0)
     val_data = get_wikitext2(
         nsamples=args.nsamples, tokenizer=tokenizer, seqlen=args.seqlen, split='validation', seed=0)
-    val_data = create_validation_dataloader(val_data, args.seqlen)
+    device = next(iter(model.parameters())).device
+    val_data = create_validation_dataloader(val_data, args.seqlen, device)
     print("Data loaded.")
 
     # Apply LN affine merging before inserting MHA layers
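
Usage note (not part of the patch): a minimal, self-contained sketch of the device-agnostic pattern the diff adopts, deriving the target device from the model's own parameters instead of calling .cuda(). The toy nn.Linear stands in for the quantized LLM; everything below is illustrative only and mirrors, rather than reproduces, the helpers in eval.py.

    import torch
    from torch import nn

    model = nn.Linear(8, 8)  # stand-in for the quantized model; stays on CPU unless moved
    device = next(iter(model.parameters())).device  # cpu here, cuda if the model was moved there

    batch = torch.randint(0, 100, (1, 16))
    batch = batch.to(device)                 # replaces the hardcoded batch.cuda()
    attention_mask = torch.ones_like(batch)  # mask is created on the same device as batch
    print(batch.device, attention_mask.device)

Because the device is read off the model, the same evaluation path works on CPU-only machines and on CUDA without any code change.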