Skip to content

Commit

Permalink
Add torch_device to the VE pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
anton-l committed Jul 21, 2022
1 parent a73ae3e commit 7c0a861
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,23 @@ def __init__(self, model, scheduler):
self.register_modules(model=model, scheduler=scheduler)

@torch.no_grad()
def __call__(self, num_inference_steps=2000, generator=None, output_type="pil"):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch_device=None, output_type="pil"):
if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

img_size = self.model.config.sample_size
shape = (1, 3, img_size, img_size)
shape = (batch_size, 3, img_size, img_size)

model = self.model.to(device)
model = self.model.to(torch_device)

sample = torch.randn(*shape) * self.scheduler.config.sigma_max
sample = sample.to(device)
sample = sample.to(torch_device)

self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.set_sigmas(num_inference_steps)

for i, t in tqdm(enumerate(self.scheduler.timesteps)):
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device)
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=torch_device)

# correction step
for _ in range(self.scheduler.correct_steps):
Expand Down

0 comments on commit 7c0a861

Please sign in to comment.