From 0648ba79d98656b4b483637811c20da911ba5025 Mon Sep 17 00:00:00 2001 From: Wojtek Kowaluk Date: Thu, 6 Apr 2023 17:19:38 +0200 Subject: [PATCH] Fixes to run on CPU and MPS --- kandinsky2/kandinsky2_1_model.py | 6 ++++-- kandinsky2/model/gaussian_diffusion.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/kandinsky2/kandinsky2_1_model.py b/kandinsky2/kandinsky2_1_model.py index 2ae2ddd..f73f83e 100644 --- a/kandinsky2/kandinsky2_1_model.py +++ b/kandinsky2/kandinsky2_1_model.py @@ -30,6 +30,8 @@ def __init__( ): self.config = config self.device = device + if not torch.has_cuda: + self.config["model_config"]["use_fp16"] = False self.use_fp16 = self.config["model_config"]["use_fp16"] self.task_type = task_type self.clip_image_size = config["clip_image_size"] @@ -54,7 +56,7 @@ def __init__( clip_mean, clip_std, ) - self.prior.load_state_dict(torch.load(prior_path), strict=False) + self.prior.load_state_dict(torch.load(prior_path, map_location='cpu'), strict=False) if self.use_fp16: self.prior = self.prior.half() self.text_encoder = TextEncoder(**self.config["text_enc_params"]) @@ -88,7 +90,7 @@ def __init__( self.config["model_config"]["cache_text_emb"] = True self.model = create_model(**self.config["model_config"]) - self.model.load_state_dict(torch.load(model_path)) + self.model.load_state_dict(torch.load(model_path, map_location='cpu')) if self.use_fp16: self.model.convert_to_fp16() self.image_encoder = self.image_encoder.half() diff --git a/kandinsky2/model/gaussian_diffusion.py b/kandinsky2/model/gaussian_diffusion.py index b5449e1..1a8d2b0 100644 --- a/kandinsky2/model/gaussian_diffusion.py +++ b/kandinsky2/model/gaussian_diffusion.py @@ -822,7 +822,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape): dimension equal to the length of timesteps. :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. """ - res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + res = th.from_numpy(arr).to(dtype=th.float32).to(device=timesteps.device)[timesteps] while len(res.shape) < len(broadcast_shape): res = res[..., None] return res.expand(broadcast_shape)