Skip to content

Commit

Permalink
add timestep_spacing to params
Browse files Browse the repository at this point in the history
  • Loading branch information
noskill committed May 13, 2024
1 parent d9172e9 commit df82f01
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions multigen/pipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def try_set_scheduler(self, inputs):

def load_lora(self, path, multiplier=1.0):
self.pipe.load_lora_weights(path)
self.pipe_params['cross_attention_kwargs'] = {"scale": multiplier}
self.pipe.fuse_lora(lora_scale=multiplier)

def add_hypernet(self, path, multiplier=None):
from . hypernet import add_hypernet, clear_hypernets, Hypernetwork
Expand All @@ -150,6 +150,8 @@ def clear_hypernets(self):
def get_config(self):
cfg = {"hypernetworks": self.hypernets }
cfg.update({"model_id": self.model_id })
cfg['scheduler'] = dict(self.pipe.scheduler.config)
cfg['scheduler']['class_name'] = self.pipe.scheduler.__class__.__name__
cfg.update(self.pipe_params)
return cfg

Expand All @@ -159,8 +161,10 @@ def setup(self, steps=50, clip_skip=0, **args):
assert clip_skip <= 10
self.pipe_params['clip_skip'] = clip_skip
if 'scheduler' in args:
# TODO? add scheduler to config?
self.try_set_scheduler(dict(scheduler=args['scheduler']))
self.try_set_scheduler(args)
if 'timestep_spacing' in args:
self.pipe.scheduler = self.pipe.scheduler.from_config(self.pipe.scheduler.config, timestep_spacing = args['timestep_spacing'])
args.pop('timestep_spacing')

def from_pipe(self, pipe, **args):
if isinstance(pipe, StableDiffusionXLPipeline):
Expand Down

0 comments on commit df82f01

Please sign in to comment.