From 934ad4e1f6235d23b475fb12857bf7347732642d Mon Sep 17 00:00:00 2001 From: proximasan Date: Fri, 3 Nov 2023 22:34:47 +0100 Subject: [PATCH] move untruncated_ids to GPU to fix tensor device mismatch error --- examples/dreambooth/train_dreambooth_sdxl_TI.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/dreambooth/train_dreambooth_sdxl_TI.py b/examples/dreambooth/train_dreambooth_sdxl_TI.py index a1d9c815065e..a9ad7f67fa3b 100644 --- a/examples/dreambooth/train_dreambooth_sdxl_TI.py +++ b/examples/dreambooth/train_dreambooth_sdxl_TI.py @@ -456,6 +456,7 @@ def encode_prompt(text_encoders, tokenizers, prompt): ) text_input_ids = text_inputs.input_ids.to(text_encoder.device) untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + untruncated_ids = untruncated_ids.to('cuda') if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])