diff --git a/intel_extension_for_transformers/transformers/llm/quantization/utils.py b/intel_extension_for_transformers/transformers/llm/quantization/utils.py
index a8a5b88baf9..d23971cfc11 100644
--- a/intel_extension_for_transformers/transformers/llm/quantization/utils.py
+++ b/intel_extension_for_transformers/transformers/llm/quantization/utils.py
@@ -438,7 +438,6 @@ def default_run_fn(
     model, tokenizer, dataset, max_length=512, n_samples=100, batch_size=8, algo="rtn"
 ):
     from torch.utils.data import DataLoader
-
     if isinstance(dataset, (str, bytes, os.PathLike)):
         calib_dataset = load_dataset(dataset, split="train")
         calib_dataset = calib_dataset.shuffle(seed=42)
@@ -513,7 +512,7 @@ def collate_batch(batch):
         try:
             model(
-                input_ids=input_ids,
+                input_ids=input_ids.to("xpu:0"),
             )
         except ValueError:
             pass
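
Note on the second hunk: the patch moves each calibration batch onto the XPU device before the forward pass, so the inputs land on the same device as the model during quantization calibration. Below is a minimal sketch of that pattern, not code from the patched file: `run_calibration` and `calib_dataloader` are hypothetical names, and deriving the device from the model's parameters (rather than hardcoding `"xpu:0"` as the diff does) is an assumption shown only to illustrate a more portable variant.

```python
import torch

def run_calibration(model, calib_dataloader, n_samples=100):
    # Hypothetical helper illustrating the pattern the diff applies:
    # move each calibration batch to the device the model lives on
    # before running the forward pass.
    device = next(model.parameters()).device  # e.g. torch.device("xpu:0")
    total_cnt = 0
    for input_ids in calib_dataloader:
        if total_cnt >= n_samples:
            break
        total_cnt += input_ids.shape[0]
        try:
            # The patch hardcodes .to("xpu:0"); deriving the device from the
            # model behaves the same when the model is on XPU and stays
            # portable if it is on CPU or CUDA instead.
            model(input_ids=input_ids.to(device))
        except ValueError:
            # Calibration only needs the forward pass to run so observers can
            # record activation statistics; ignore models that reject this
            # minimal argument set.
            pass
```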