Skip to content

Commit

Permalink
Fixes to run on CPU and MPS
Browse files Browse the repository at this point in the history
  • Loading branch information
Wojtek Kowaluk authored and Wojtek Kowaluk committed Apr 7, 2023
1 parent a4354c0 commit 0648ba7
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
6 changes: 4 additions & 2 deletions kandinsky2/kandinsky2_1_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def __init__(
):
self.config = config
self.device = device
if not torch.has_cuda:
self.config["model_config"]["use_fp16"] = False
self.use_fp16 = self.config["model_config"]["use_fp16"]
self.task_type = task_type
self.clip_image_size = config["clip_image_size"]
Expand All @@ -54,7 +56,7 @@ def __init__(
clip_mean,
clip_std,
)
self.prior.load_state_dict(torch.load(prior_path), strict=False)
self.prior.load_state_dict(torch.load(prior_path, map_location='cpu'), strict=False)
if self.use_fp16:
self.prior = self.prior.half()
self.text_encoder = TextEncoder(**self.config["text_enc_params"])
Expand Down Expand Up @@ -88,7 +90,7 @@ def __init__(

self.config["model_config"]["cache_text_emb"] = True
self.model = create_model(**self.config["model_config"])
self.model.load_state_dict(torch.load(model_path))
self.model.load_state_dict(torch.load(model_path, map_location='cpu'))
if self.use_fp16:
self.model.convert_to_fp16()
self.image_encoder = self.image_encoder.half()
Expand Down
2 changes: 1 addition & 1 deletion kandinsky2/model/gaussian_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
res = th.from_numpy(arr).to(dtype=th.float32).to(device=timesteps.device)[timesteps]
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res.expand(broadcast_shape)

0 comments on commit 0648ba7

Please sign in to comment.