
Merge pull request #71 from noskill/lpw
lpw tests
Necr0x0Der authored Jul 22, 2024
2 parents 75d2935 + 26dfdd9 commit df5157c
Showing 7 changed files with 140 additions and 11 deletions.
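The changes below thread a LoRA scale through the long-prompt-weighting (lpw) text-embedding path, key the GPU pipeline cache by an explicit device index, and add tests for lpw. For orientation, a minimal usage sketch of the lpw flag (model id and prompt are illustrative; remaining arguments are assumed to keep their defaults):

from multigen.pipes import Prompt2ImPipe

# lpw=True routes prompt encoding through the weighted, length-unlimited embedders
pipe = Prompt2ImPipe("runwayml/stable-diffusion-v1-5", lpw=True)
pipe.setup(width=512, height=512, steps=30)
image = pipe.gen(dict(prompt="a (very:1.2) long prompt that runs well past the 77-token limit ..."))
image.save("result.png")
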
4 changes: 4 additions & 0 deletions .github/workflows/python-app.yml
@@ -27,7 +27,11 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+    - name: symlink models directory
+      run: ln -s ../../../../models tests/models-full
     - name: Test with pytest
+      env:
+        METAFUSION_MODELS_DIR: models-full
       run: |
         cd tests && python pipe_test.py
     - name: Test worker
7 changes: 5 additions & 2 deletions multigen/loader.py
@@ -63,7 +63,7 @@ def load_pipeline(self, cls: Type[DiffusionPipeline], path, torch_dtype=torch.fl
         logger.debug(f'looking for pipeline {cls} from {path} on {device}')
         result = None
         if device is None:
-            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu', 0)
         if device.type == 'cuda':
             idx = device.index
             gpu_pipes = self._gpu_pipes.get(idx, [])
@@ -89,7 +89,9 @@ def load_pipeline(self, cls: Type[DiffusionPipeline], path, torch_dtype=torch.fl
         result = result.to(dtype=torch_dtype, device=device)
         self.cache_pipeline(result, path)
         result = copy_pipe(result)
-        assert result.device == device
+        assert result.device.type == device.type
+        if device.type == 'cuda':
+            assert result.device.index == device.index
         logger.debug(f'returning {type(result)} from {path} on {result.device}')
         return result
 
@@ -131,6 +133,7 @@ def clear_cache(self, device):

     def _store_gpu_pipe(self, pipe, model_id):
         idx = pipe.device.index
+        assert idx is not None
         # for now just clear all other pipelines
         self._gpu_pipes[idx] = [(model_id, pipe)]
 
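The new asserts rely on the pipeline cache being keyed by a concrete GPU index; a standalone torch illustration of why the bare 'cuda' device is not enough (not part of the commit):

import torch

dev_default = torch.device('cuda')      # no index: dev_default.index is None
dev_explicit = torch.device('cuda', 0)  # explicit index: dev_explicit.index == 0
print(dev_default.index, dev_explicit.index)  # None 0

# _store_gpu_pipe() keys self._gpu_pipes by pipe.device.index, so an un-indexed
# device would file the pipeline under None and break per-GPU lookups; hence the
# added `assert idx is not None` and constructing the device with an explicit index.
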
24 changes: 24 additions & 0 deletions multigen/lpw_stable_diffusion.py
@@ -22,6 +22,14 @@
     logging,
 )
 from diffusers.utils.torch_utils import randn_tensor
+from diffusers.utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 
 
 # ------------------------------------------------------------------------------
@@ -261,6 +269,7 @@ def get_weighted_text_embeddings(
     skip_parsing: Optional[bool] = False,
     skip_weighting: Optional[bool] = False,
     clip_skip=None,
+    lora_scale=None,
 ):
     r"""
     Prompts can be assigned with local weights using brackets. For example,
@@ -287,6 +296,16 @@
         skip_weighting (`bool`, *optional*, defaults to `False`):
             Skip the weighting. When the parsing is skipped, it is forced True.
     """
+    # set lora scale so that monkey patched LoRA
+    # function of text encoder can correctly access it
+    if lora_scale is not None and isinstance(pipe, LoraLoaderMixin):
+        pipe._lora_scale = lora_scale
+
+        # dynamically adjust the LoRA scale
+        if not USE_PEFT_BACKEND:
+            adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
+        else:
+            scale_lora_layers(pipe.text_encoder, lora_scale)
     max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
     if isinstance(prompt, str):
         prompt = [prompt]
@@ -383,6 +402,11 @@ def get_weighted_text_embeddings(
             current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
             uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
 
+    if pipe.text_encoder is not None:
+        if isinstance(pipe, LoraLoaderMixin) and USE_PEFT_BACKEND:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(pipe.text_encoder, lora_scale)
+
     if uncond_prompt is not None:
         return text_embeddings, uncond_embeddings
     return text_embeddings, None
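The added code mirrors the pattern diffusers uses in its own encode_prompt: scale the text encoder's LoRA layers before encoding, then scale them back afterwards. A condensed sketch of that pattern under the PEFT backend (hypothetical helper name, not the literal code above; assumes `pipe.text_encoder` carries LoRA layers):

from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers

def with_lora_scale(pipe, lora_scale, encode):
    # temporarily apply lora_scale to the text encoder's LoRA layers,
    # run the encoding callable, then restore the original scale
    scaled = lora_scale is not None and USE_PEFT_BACKEND
    if scaled:
        scale_lora_layers(pipe.text_encoder, lora_scale)
    try:
        return encode()
    finally:
        if scaled:
            unscale_lora_layers(pipe.text_encoder, lora_scale)
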
40 changes: 38 additions & 2 deletions multigen/lpw_stable_diffusion_xl.py
@@ -22,7 +22,7 @@

 from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
 from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
-from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.loaders import StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
 from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
@@ -37,7 +37,14 @@
     replace_example_docstring,
 )
 from diffusers.utils.torch_utils import randn_tensor
-
+from diffusers.utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 
 if is_invisible_watermark_available():
     from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
@@ -256,6 +263,7 @@ def get_weighted_text_embeddings_sdxl(
     num_images_per_prompt: int = 1,
     device: Optional[torch.device] = None,
     clip_skip: Optional[int] = None,
+    lora_scale: Optional[int] = None
 ):
     """
     This function can process long prompt with weights, no length limitation
@@ -276,6 +284,24 @@
"""
device = device or pipe._execution_device

# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(pipe, StableDiffusionXLLoraLoaderMixin):
pipe._lora_scale = lora_scale

# dynamically adjust the LoRA scale
if pipe.text_encoder is not None:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
else:
scale_lora_layers(pipe.text_encoder, lora_scale)

if pipe.text_encoder_2 is not None:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(pipe.text_encoder_2, lora_scale)
else:
scale_lora_layers(pipe.text_encoder_2, lora_scale)

if prompt_2:
prompt = f"{prompt} {prompt_2}"

@@ -424,6 +450,16 @@ def get_weighted_text_embeddings_sdxl(
         bs_embed * num_images_per_prompt, -1
     )
 
+    if pipe.text_encoder is not None:
+        if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(pipe.text_encoder, lora_scale)
+
+    if pipe.text_encoder_2 is not None:
+        if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(pipe.text_encoder_2, lora_scale)
+
     return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
 
 
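As the return statement above shows, the SDXL helper yields four embedding tensors; a sketch of feeding them to a standard StableDiffusionXLPipeline call (assumes `pipe` is an already loaded StableDiffusionXLPipeline; prompt text and LoRA scale are illustrative):

from multigen import lpw_stable_diffusion_xl

(prompt_embeds,
 negative_prompt_embeds,
 pooled_prompt_embeds,
 negative_pooled_prompt_embeds) = lpw_stable_diffusion_xl.get_weighted_text_embeddings_sdxl(
    pipe,
    prompt="a (weighted:1.2) description long enough to exceed the usual token limit ...",
    neg_prompt="blurry, low quality",
    num_images_per_prompt=1,
    clip_skip=None,
    lora_scale=0.8,
)
image = pipe(prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds).images[0]
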
14 changes: 10 additions & 4 deletions multigen/pipes.py
@@ -105,6 +105,7 @@ def __init__(self, model_id: str,
         self.pipe = self._load_pipeline(sd_pipe_class, model_type, args)
         self._initialize_pipe(device)
         self.lpw = lpw
+        self._loras = []
         mt = self._get_model_type()
         if self.model_type is None:
             self.model_type = mt
@@ -183,6 +184,7 @@ def load_lora(self, path, multiplier=1.0):
         if 'cross_attention_kwargs' not in self.pipe_params:
             self.pipe_params['cross_attention_kwargs'] = {}
         self.pipe_params['cross_attention_kwargs']["scale"] = multiplier
+        self._loras.append(path)
 
     def add_hypernet(self, path, multiplier=None):
         from . hypernet import add_hypernet, Hypernetwork
@@ -209,6 +211,7 @@ def get_config(self):
         cfg.update({"model_id": self.model_id })
         cfg['scheduler'] = dict(self.pipe.scheduler.config)
         cfg['scheduler']['class_name'] = self.pipe.scheduler.__class__.__name__
+        cfg['loras'] = self._loras
         cfg.update(self.pipe_params)
         return cfg
 
@@ -230,7 +233,7 @@ def setup(self, steps=50, clip_skip=0, loras=[], **args):
         Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
         :return: None
         """
-        self.pipe_params = { 'num_inference_steps': steps }
+        self.pipe_params.update({ 'num_inference_steps': steps })
         assert clip_skip >= 0
         assert clip_skip <= 10
         self.pipe_params['clip_skip'] = clip_skip
@@ -242,7 +245,7 @@ def setup(self, steps=50, clip_skip=0, loras=[], **args):
         for lora in loras:
             self.load_lora(lora)
 
-    def get_prompt_embeds(self, prompt, negative_prompt, clip_skip: Optional[int] = None):
+    def get_prompt_embeds(self, prompt, negative_prompt, clip_skip: Optional[int] = None, lora_scale: Optional[int] = None):
         if self.lpw:
             # convert to lpw
             if isinstance(self.pipe, self._classxl):
@@ -255,6 +258,7 @@ def get_prompt_embeds(self, prompt, negative_prompt, clip_skip: Optional[int] =
                     neg_prompt=negative_prompt,
                     num_images_per_prompt=1,
                     clip_skip=clip_skip,
+                    lora_scale=lora_scale
                 )
             elif isinstance(self.pipe, self._class):
                 from . import lpw_stable_diffusion
@@ -264,12 +268,14 @@ def get_prompt_embeds(self, prompt, negative_prompt, clip_skip: Optional[int] =
                     uncond_prompt=negative_prompt,
                     max_embeddings_multiples=3,
                     clip_skip=clip_skip,
+                    lora_scale=lora_scale
                 )
 
     def prepare_inputs(self, inputs):
         kwargs = self.pipe_params.copy()
         kwargs.update(inputs)
         if self.lpw:
+            lora_scale = kwargs.get('cross_attention_kwargs', dict()).get("scale", None)
             if self.model_type == ModelType.SDXL:
                 if 'negative_prompt' not in kwargs:
                     kwargs['negative_prompt'] = None
@@ -283,7 +289,7 @@ def prepare_inputs(self, inputs):
                     negative_prompt_embeds,
                     pooled_prompt_embeds,
                     negative_pooled_prompt_embeds,
-                ) = self.get_prompt_embeds(kwargs.pop('prompt'), kwargs.pop('negative_prompt'), kwargs.pop('clip_skip'))
+                ) = self.get_prompt_embeds(kwargs.pop('prompt'), kwargs.pop('negative_prompt'), kwargs.pop('clip_skip'), lora_scale=lora_scale)
 
                 kwargs['prompt_embeds'] = prompt_embeds
                 kwargs['negative_prompt_embeds'] = negative_prompt_embeds
@@ -294,7 +300,7 @@ def prepare_inputs(self, inputs):
                     prompt=kwargs.pop('prompt'),
                     negative_prompt=kwargs.pop('negative_prompt'),
                     clip_skip=kwargs.pop('clip_skip'),
-                )
+                    lora_scale=lora_scale)
                 kwargs['prompt_embeds'] = prompt_embeds
                 kwargs['negative_prompt_embeds'] = negative_prompt_embeds
             else:
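Taken together, these changes let a LoRA multiplier set via load_lora reach the lpw embedders: the multiplier lands in pipe_params['cross_attention_kwargs']['scale'], prepare_inputs reads it back as lora_scale, and get_prompt_embeds forwards it. A sketch of that flow (model and LoRA paths are placeholders):

from multigen.pipes import Prompt2ImPipe

pipe = Prompt2ImPipe("runwayml/stable-diffusion-v1-5", lpw=True)
pipe.setup(steps=30)
pipe.load_lora("loras/style.safetensors", multiplier=0.8)  # also recorded in the new self._loras list

print(pipe.get_config()['loras'])  # ['loras/style.safetensors']
# During generation, prepare_inputs() picks the 0.8 multiplier out of
# cross_attention_kwargs and hands it to the weighted-embedding functions as lora_scale.
image = pipe.gen(dict(prompt="portrait in the loaded LoRA style, highly detailed"))
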
5 changes: 2 additions & 3 deletions multigen/prompting.py
@@ -57,10 +57,9 @@ def __next__(self):
             raise StopIteration
         thread_data.random = random.Random()
         seed = self.seeds[self.count % nseeds] if nseeds > 0 else \
-            thread_data.random.randint(1, 1024*1024*1024*4-1)
-        self.count += 1
-
+            random.randint(1, 1024*1024*1024*4-1)
+        thread_data.random.seed(seed)
         self.count += 1
         result = {'prompt': get_prompt(self.prompt),
                   'generator': seed,
                   'negative_prompt': get_prompt(self.negative_prompt)}
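The prompting change draws the seed from the module-level RNG and then seeds the per-thread RNG with it, so later per-thread draws can be reproduced from the recorded seed. The standard-library behaviour it relies on, in isolation:

import random
import threading

thread_data = threading.local()
thread_data.random = random.Random()

seed = random.randint(1, 1024 * 1024 * 1024 * 4 - 1)  # module-level RNG picks the seed
thread_data.random.seed(seed)                          # per-thread RNG becomes reproducible

# the same seed always reproduces the same per-thread stream
assert thread_data.random.random() == random.Random(seed).random()
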
57 changes: 57 additions & 0 deletions tests/pipe_test.py
@@ -22,12 +22,21 @@ def compute_diff(self, im1: PIL.Image.Image, im2: PIL.Image.Image) -> float:



+def can_run_lpw():
+    if os.environ.get('METAFUSION_MODELS_DIR'):
+        return True
+    return False
+
+
 class MyTestCase(TestCase):
 
     def setUp(self):
         self._pipeline = None
 
     def get_model(self):
+        models_dir = os.environ.get('METAFUSION_MODELS_DIR', None)
+        if models_dir is not None:
+            return models_dir + '/icb_diffusers'
         return "hf-internal-testing/tiny-stable-diffusion-torch"
 
     def get_ref_image(self):
@@ -116,11 +125,59 @@ def test_img2img_basic(self):
         result = pipe.gen(dict(prompt="cube planet cartoon style"))
         result.save('test_img2img_basic.png')
 
+    @unittest.skipIf(not can_run_lpw(), "can't run on tiny version of SD")
+    def test_lpw(self):
+        """
+        Check that the last part of a long prompt affects the generation
+        """
+        pipe = Prompt2ImPipe(self.get_model(), model_type=self.model_type(), lpw=True)
+        prompt = ' a cubic planet with atmoshere as seen from low orbit, each side of the cubic planet is ocuppied by an ocean, oceans have islands, but no continents, atmoshere of the planet has usual sperical shape, corners of the cube are above the atmoshere, but edges largely are covered by the atomosphere, there are cyclones in the atmoshere, the photo is made from low-orbit, famous sci-fi illustration'
+        pipe.setup(width=512, height=512, guidance_scale=7, scheduler="DPMSolverMultistepScheduler", steps=5)
+        seed = 49045438434843
+        params = dict(prompt=prompt,
+                      negative_prompt="spherical",
+                      generator=torch.cuda.manual_seed(seed))
+        image = pipe.gen(params)
+        image.save("cube_test_lpw.png")
+        params = dict(prompt=prompt + " , best quality, famous photo",
+                      negative_prompt="spherical",
+                      generator=torch.cuda.manual_seed(seed))
+        image1 = pipe.gen(params)
+        image1.save("cube_test_lpw1.png")
+        diff = self.compute_diff(image1, image)
+        # check that the difference is large
+        self.assertGreater(diff, 1000)
+
+    @unittest.skipIf(not can_run_lpw(), "can't run on tiny version of SD")
+    def test_lpw_turned_off(self):
+        """
+        Check that the last part of a long prompt doesn't affect the generation with lpw turned off
+        """
+        pipe = Prompt2ImPipe(self.get_model(), model_type=self.model_type(), lpw=False)
+        prompt = ' a cubic planet with atmoshere as seen from low orbit, each side of the cubic planet is ocuppied by an ocean, oceans have islands, but no continents, atmoshere of the planet has usual sperical shape, corners of the cube are above the atmoshere, but edges largely are covered by the atomosphere, there are cyclones in the atmoshere, the photo is made from low-orbit, famous sci-fi illustration'
+        pipe.setup(width=512, height=512, guidance_scale=7, scheduler="DPMSolverMultistepScheduler", steps=5)
+        seed = 49045438434843
+        params = dict(prompt=prompt,
+                      negative_prompt="spherical",
+                      generator=torch.cuda.manual_seed(seed))
+        image = pipe.gen(params)
+        image.save("cube_test_no_lpw.png")
+        params = dict(prompt=prompt + " , best quality, famous photo",
+                      negative_prompt="spherical",
+                      generator=torch.cuda.manual_seed(seed))
+        image1 = pipe.gen(params)
+        image1.save("cube_test_no_lpw1.png")
+        diff = self.compute_diff(image1, image)
+        # check that the difference is small
+        self.assertLess(diff, 1)
 

 class TestSDXL(MyTestCase):
 
     def get_model(self):
+        models_dir = os.environ.get('METAFUSION_MODELS_DIR', None)
+        if models_dir is not None:
+            return models_dir + '/SDXL/stable-diffusion-xl-base-1.0'
         return "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
 
 
