Commit

Remove old logic except for inpaint, add support for lora and ti to inpaint node

StAlKeR7779 committed Jun 17, 2023
1 parent 77248c2 commit 956a6aa
Showing 14 changed files with 64 additions and 1,824 deletions.
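
In outline: the inpaint node loses its raw prompt field and its node-local textual-inversion loading, and instead takes pre-built conditioning as graph inputs (produced upstream, e.g. by the compel node), while LoRA patching moves into the node's model-loading path. A minimal pydantic sketch of the node's new input surface — field names and defaults mirror the diff below, but BaseModel stands in for BaseInvocation and the example names are illustrative, not part of the commit:

from typing import Literal, Optional
from pydantic import BaseModel, Field

class ConditioningField(BaseModel):
    # Names a conditioning tensor held by the latents service.
    conditioning_name: str

class InpaintNodeInputs(BaseModel):
    # Stand-in for BaseInvocation; the prompt field is gone, and
    # positive/negative conditioning arrive as graph inputs instead.
    type: Literal["inpaint"] = "inpaint"
    positive_conditioning: Optional[ConditioningField] = Field(
        default=None, description="Positive conditioning for generation")
    negative_conditioning: Optional[ConditioningField] = Field(
        default=None, description="Negative conditioning for generation")
    steps: int = Field(default=30, gt=0)
    cfg_scale: float = Field(default=7.5, ge=1)
    scheduler: str = Field(default="euler")

inputs = InpaintNodeInputs(
    positive_conditioning=ConditioningField(conditioning_name="pos_cond_1"),
    negative_conditioning=ConditioningField(conditioning_name="neg_cond_1"),
)
print(inputs.type)  # inpaint
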
99 changes: 37 additions & 62 deletions invokeai/app/invocations/generate.py
@@ -21,7 +21,8 @@
import re
from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
from .model import UNetField, ClipField, VaeField
from .model import UNetField, VaeField
from .compel import ConditioningField
from contextlib import contextmanager, ExitStack, ContextDecorator

SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
@@ -63,19 +64,15 @@ class InpaintInvocation(BaseInvocation):

type: Literal["inpaint"] = "inpaint"

prompt: Optional[str] = Field(description="The prompt to generate an image from")
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
#model: str = Field(default="", description="The model to use (currently ignored)")
#progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
#control_model: Optional[str] = Field(default=None, description="The control model to use")
#control_image: Optional[ImageField] = Field(default=None, description="The processed control image")
unet: UNetField = Field(default=None, description="UNet model")
clip: ClipField = Field(default=None, description="Clip model")
vae: VaeField = Field(default=None, description="Vae model")

# Inputs
@@ -151,64 +148,34 @@ def dispatch_progress(
source_node_id=source_node_id,
)

def get_conditioning(self, context):
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)

return (uc, c, extra_conditioning_info)

@contextmanager
def load_model_old_way(self, context):
with ExitStack() as stack:
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
tokenizer_info = context.services.model_manager.get_model(**self.clip.tokenizer.dict())
text_encoder_info = context.services.model_manager.get_model(**self.clip.text_encoder.dict())
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())

#unet = stack.enter_context(unet_info)
#tokenizer = stack.enter_context(tokenizer_info)
#text_encoder = stack.enter_context(text_encoder_info)
#vae = stack.enter_context(vae_info)
with vae_info as vae:
device = vae.device
dtype = vae.dtype

# not load models to gpu as it should be handled by pipeline
unet = unet_info.context.model
tokenizer = tokenizer_info.context.model
text_encoder = text_encoder_info.context.model
vae = vae_info.context.model

scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)
def load_model_old_way(self, context, scheduler):
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())

#unet = unet_info.context.model
#vae = vae_info.context.model

with ExitStack() as stack:
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
name = trigger[1:-1]
try:
ti_list.append(
stack.enter_context(
context.services.model_manager.get_model(
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
)
)
)
except Exception:
#print(e)
#import traceback
#print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found")


with ModelPatcher.apply_lora_unet(unet, loras),\
ModelPatcher.apply_lora_text_encoder(text_encoder, loras),\
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (ti_tokenizer, ti_manager):

with vae_info as vae,\
unet_info as unet,\
ModelPatcher.apply_lora_unet(unet, loras):

device = context.services.model_manager.mgr.cache.execution_device
dtype = context.services.model_manager.mgr.cache.precision

pipeline = StableDiffusionGeneratorPipeline(
# TODO: ti_manager
vae=vae,
text_encoder=text_encoder,
tokenizer=ti_tokenizer,
text_encoder=None,
tokenizer=None,
unet=unet,
scheduler=scheduler,
safety_checker=None,
@@ -242,14 +209,22 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]

with self.load_model_old_way(context) as model:
conditioning = self.get_conditioning(context)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)

with self.load_model_old_way(context, scheduler) as model:
outputs = Inpaint(model).generate(
prompt=self.prompt,
conditioning=conditioning,
scheduler=scheduler,
init_image=image,
mask_image=mask,
step_callback=partial(self.dispatch_progress, context, source_node_id),
**self.dict(
exclude={"prompt", "image", "mask"}
exclude={"positive_conditioning", "negative_conditioning", "scheduler", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)

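The rewritten load_model_old_way above is built around a context-manager discipline: an ExitStack keeps every fetched LoRA model alive, ModelPatcher.apply_lora_unet patches the UNet only for the duration of the with-block, and the original weights come back on exit. A self-contained sketch of that pattern, with a plain dict standing in for the UNet and an illustrative apply_lora that is not the ModelPatcher implementation:

from contextlib import ExitStack, contextmanager

@contextmanager
def apply_lora(model: dict, loras: list):
    # Add each weighted LoRA delta in place, then restore on exit --
    # even if generation raises inside the with-block.
    original = dict(model)
    try:
        for lora, weight in loras:
            for key, delta in lora.items():
                model[key] = model.get(key, 0.0) + weight * delta
        yield model
    finally:
        model.clear()
        model.update(original)

unet = {"down.weight": 1.0}
loras = [({"down.weight": 0.25}, 2.0)]  # (lora, weight) pairs, as in self.unet.loras
with ExitStack() as stack:
    patched = stack.enter_context(apply_lora(unet, loras))
    assert patched["down.weight"] == 1.5  # patched while generating
assert unet["down.weight"] == 1.0  # restored afterwards

Conditioning follows the same handle-passing style: get_conditioning no longer computes embeddings from a prompt, it looks up tensors the compel node already stored in the latents service.
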
54 changes: 17 additions & 37 deletions invokeai/backend/generator/base.py
@@ -29,7 +29,6 @@
from ..image_util import configure_model_padding
from ..util.util import rand_perlin_2d
from ..safety_checker import SafetyChecker
from ..prompting.conditioning import get_uc_and_c_and_ec
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ..stable_diffusion.schedulers import SCHEDULER_MAP

@@ -81,13 +80,15 @@ def __init__(self,
self.params=params
self.kwargs = kwargs

def generate(self,
prompt: str='',
callback: Optional[Callable]=None,
step_callback: Optional[Callable]=None,
iterations: int=1,
**keyword_args,
)->Iterator[InvokeAIGeneratorOutput]:
def generate(
self,
conditioning: tuple,
scheduler,
callback: Optional[Callable]=None,
step_callback: Optional[Callable]=None,
iterations: int=1,
**keyword_args,
)->Iterator[InvokeAIGeneratorOutput]:
'''
Return an iterator across the indicated number of generations.
Each time the iterator is called it will return an InvokeAIGeneratorOutput
@@ -116,11 +117,6 @@ def generate(self,
model_name = model_info.name
model_hash = model_info.hash
with model_info.context as model:
scheduler: Scheduler = self.get_scheduler(
model=model,
scheduler_name=generator_args.get('scheduler')
)
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
gen_class = self._generator_class()
generator = gen_class(model, self.params.precision, **self.kwargs)
if self.params.variation_amount > 0:
@@ -143,12 +139,12 @@

iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
for i in iteration_count:
results = generator.generate(prompt,
conditioning=(uc, c, extra_conditioning_info),
step_callback=step_callback,
sampler=scheduler,
**generator_args,
)
results = generator.generate(
conditioning=conditioning,
step_callback=step_callback,
sampler=scheduler,
**generator_args,
)
output = InvokeAIGeneratorOutput(
image=results[0][0],
seed=results[0][1],
@@ -170,20 +166,6 @@ def schedulers(self)->List[str]:
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
return generator_class(model, self.params.precision)

def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])

scheduler_config = model.scheduler.config
if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"]
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
scheduler = scheduler_class.from_config(scheduler_config)

# hack copied over from generate.py
if not hasattr(scheduler, 'uses_inpainting_model'):
scheduler.uses_inpainting_model = lambda: False
return scheduler

@classmethod
def _generator_class(cls)->Type[Generator]:
'''
@@ -281,7 +263,7 @@ def __init__(self, model: DiffusionPipeline, precision: str, **kwargs):
self.model = model
self.precision = precision
self.seed = None
self.latent_channels = model.channels
self.latent_channels = model.unet.config.in_channels
self.downsampling_factor = downsampling # BUG: should come from model or config
self.safety_checker = None
self.perlin = 0.0
@@ -292,7 +274,7 @@ def __init__(self, model: DiffusionPipeline, precision: str, **kwargs):
self.free_gpu_mem = None

# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
def get_make_image(self, prompt, **kwargs):
def get_make_image(self, **kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it
@@ -308,7 +290,6 @@ def set_variation(self, seed, variation_amount, with_variations):

def generate(
self,
prompt,
width,
height,
sampler,
@@ -333,7 +314,6 @@ def generate(
saver.get_stacked_maps_image()
)
make_image = self.get_make_image(
prompt,
sampler=sampler,
init_image=init_image,
width=width,
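The get_scheduler method removed from base.py is not lost: scheduler construction now happens at the node level (see the invoke() hunk in generate.py above) and the built scheduler is passed down into generate(). Its core trick, visible verbatim in the removed lines, is keeping the model's original scheduler config under a "_backup" key so that repeated scheduler swaps always rebuild from the original settings rather than from the previous swap's output. A sketch of that rebuild, assuming the diffusers scheduler classes and an illustrative two-entry SCHEDULER_MAP (the real one lives in invokeai.backend.stable_diffusion.schedulers):

from diffusers import DDIMScheduler, EulerDiscreteScheduler

# Illustrative stand-in for SCHEDULER_MAP: scheduler class plus
# extra config entries to overlay on the model's config.
SCHEDULER_MAP = {
    "ddim": (DDIMScheduler, {}),
    "euler": (EulerDiscreteScheduler, {}),
}

def rebuild_scheduler(current_scheduler, scheduler_name: str):
    scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
        scheduler_name, SCHEDULER_MAP["ddim"]
    )
    scheduler_config = dict(current_scheduler.config)
    if "_backup" in scheduler_config:
        # A swap already happened once: start over from the original config.
        scheduler_config = scheduler_config["_backup"]
    scheduler_config = {**scheduler_config, **scheduler_extra_config,
                        "_backup": scheduler_config}
    return scheduler_class.from_config(scheduler_config)

scheduler = rebuild_scheduler(DDIMScheduler(), "euler")
print(type(scheduler).__name__)  # EulerDiscreteScheduler
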

0 comments on commit 956a6aa
