From f339be6e2df3986107c368cfbab529fca9a571c4 Mon Sep 17 00:00:00 2001 From: Robin San Roman Date: Fri, 11 Aug 2023 08:57:38 -0700 Subject: [PATCH] add device --- models/multibanddiffusion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/multibanddiffusion.py b/models/multibanddiffusion.py index 6a2f169d..1121d2fc 100644 --- a/models/multibanddiffusion.py +++ b/models/multibanddiffusion.py @@ -74,7 +74,7 @@ def get_mbd_musicgen(device=None): models, processors, cfgs = load_diffusion_models(path, device=device) DPs = [] for i in range(len(models)): - schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i]) + schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device) DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule)) return MultiBandDiffusion(DPs=DPs, codec_model=codec_model) @@ -106,7 +106,7 @@ def get_mbd_24khz(bw: float = 3.0, pretrained: bool = True, models, processors, cfgs = load_diffusion_models(path, device=device) DPs = [] for i in range(len(models)): - schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i]) + schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device) DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule)) return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)