Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

Permalink
Update utils.py
Browse files Browse the repository at this point in the history
Signed-off-by: Wang, Chang <[email protected]>
  • Loading branch information
changwangss authored Aug 23, 2024
1 parent 408e5f1 commit e135523
Showing 1 changed file with 1 addition and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,6 @@ def default_run_fn(
model, tokenizer, dataset, max_length=512, n_samples=100, batch_size=8, algo="rtn"
):
from torch.utils.data import DataLoader

if isinstance(dataset, (str, bytes, os.PathLike)):
calib_dataset = load_dataset(dataset, split="train")
calib_dataset = calib_dataset.shuffle(seed=42)
Expand Down Expand Up @@ -513,7 +512,7 @@ def collate_batch(batch):

try:
model(
input_ids=input_ids,
input_ids=input_ids.to("xpu:0"),
)
except ValueError:
pass
Expand Down

0 comments on commit e135523

Please sign in to comment.