diff --git a/multigen/pipes.py b/multigen/pipes.py index a611b55..a8b68b2 100755 --- a/multigen/pipes.py +++ b/multigen/pipes.py @@ -131,7 +131,7 @@ def try_set_scheduler(self, inputs): def load_lora(self, path, multiplier=1.0): self.pipe.load_lora_weights(path) - self.pipe_params['cross_attention_kwargs'] = {"scale": multiplier} + self.pipe.fuse_lora(lora_scale=multiplier) def add_hypernet(self, path, multiplier=None): from . hypernet import add_hypernet, clear_hypernets, Hypernetwork @@ -150,6 +150,8 @@ def clear_hypernets(self): def get_config(self): cfg = {"hypernetworks": self.hypernets } cfg.update({"model_id": self.model_id }) + cfg['scheduler'] = dict(self.pipe.scheduler.config) + cfg['scheduler']['class_name'] = self.pipe.scheduler.__class__.__name__ cfg.update(self.pipe_params) return cfg @@ -159,8 +161,10 @@ def setup(self, steps=50, clip_skip=0, **args): assert clip_skip <= 10 self.pipe_params['clip_skip'] = clip_skip if 'scheduler' in args: - # TODO? add scheduler to config? - self.try_set_scheduler(dict(scheduler=args['scheduler'])) + self.try_set_scheduler(args) + if 'timestep_spacing' in args: + self.pipe.scheduler = self.pipe.scheduler.from_config(self.pipe.scheduler.config, timestep_spacing = args['timestep_spacing']) + args.pop('timestep_spacing') def from_pipe(self, pipe, **args): if isinstance(pipe, StableDiffusionXLPipeline):