From 7c0a861894b1e34e6b3ae6601ac5054e3e435f29 Mon Sep 17 00:00:00 2001
From: anton-l
Date: Thu, 21 Jul 2022 13:53:09 +0200
Subject: [PATCH] Add torch_device to the VE pipeline

---
 .../pipelines/score_sde_ve/pipeline_score_sde_ve.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
index 6344a578b9cb..0ce9626bd364 100644
--- a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
+++ b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
@@ -11,22 +11,23 @@ def __init__(self, model, scheduler):
         self.register_modules(model=model, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, num_inference_steps=2000, generator=None, output_type="pil"):
-        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch_device=None, output_type="pil"):
+        if torch_device is None:
+            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
 
         img_size = self.model.config.sample_size
-        shape = (1, 3, img_size, img_size)
+        shape = (batch_size, 3, img_size, img_size)
 
-        model = self.model.to(device)
+        model = self.model.to(torch_device)
 
         sample = torch.randn(*shape) * self.scheduler.config.sigma_max
-        sample = sample.to(device)
+        sample = sample.to(torch_device)
 
         self.scheduler.set_timesteps(num_inference_steps)
         self.scheduler.set_sigmas(num_inference_steps)
 
         for i, t in tqdm(enumerate(self.scheduler.timesteps)):
-            sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=torch_device)
+            sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=torch_device)
 
             # correction step
             for _ in range(self.scheduler.correct_steps):
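
Not part of the patch: a minimal usage sketch of the updated `__call__` signature. The pipeline class name `ScoreSdeVePipeline`, its public export from `diffusers`, and the `from_pretrained` checkpoint id are assumptions about the surrounding library at this revision; only the new `batch_size` and `torch_device` arguments are taken from the diff above.

```python
# Usage sketch only, not part of the patch. Class name, import path, and
# checkpoint id are assumptions; `batch_size` and `torch_device` are the
# arguments introduced by this diff.
import torch

from diffusers import ScoreSdeVePipeline  # assumed public export

pipeline = ScoreSdeVePipeline.from_pretrained("google/ncsnpp-church-256")  # assumed checkpoint

output = pipeline(
    batch_size=2,               # new in this patch: sample more than one image per call
    num_inference_steps=2000,
    torch_device="cpu",         # new in this patch; None (the default) auto-selects CUDA when available
    output_type="pil",
)
```

Passing `torch_device=None` keeps the previous behavior, where the pipeline picks CUDA if `torch.cuda.is_available()` and falls back to CPU otherwise.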