Skip to content

Commit

Permalink
Merge pull request #85 from noskill/vbr1
Browse files Browse the repository at this point in the history
fix cpu offload setup
  • Loading branch information
noskill authored Oct 24, 2024
2 parents 8b44048 + d58754f commit a56ae2e
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 47 deletions.
108 changes: 108 additions & 0 deletions examples/work2im_temp_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""
Example of queueing an inpainting job on a ServiceThread with a temporary session.

The session dict is built ad hoc by random_session() and passed to the worker
together with the request. The script reads 'config.yaml', 'cr.png' (the image
to inpaint) and 'select.png' (the mask) using relative paths.

Originally written as an example of using Hyper-SD with metafusion:
# https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0
requires this model to be in models-sd/SDXL/stable-diffusion-xl-base-1.0/sd_xl_base_1.0.safetensors
and https://huggingface.co/ByteDance/Hyper-SD/blob/main/Hyper-SD15-4steps-lora.safetensors to be in
models-sd/Lora/Hyper-SDXL-4steps-lora.safetensors

NOTE(review): the LoRA URL names the SD15 file while the target file name says
SDXL - confirm which one is intended.
NOTE(review): nothing in this script references these files directly; the model
actually used is taken from the worker's 'base' models - confirm the
requirements above still apply.
"""

import time
import random
import PIL.Image
import torch
from multigen.prompting import Cfgen
from multigen.sessions import GenSession
from multigen.pipes import Prompt2ImPipe

from multigen.worker import ServiceThread

from multigen.log import setup_logger
setup_logger()

# Negative prompt. In this script it is only referenced by the commented-out
# queue_gen example further down.
nprompt = "jpeg artifacts, blur, distortion, watermark, signature, extra fingers, fewer fingers, lowres, bad hands, duplicate heads, bad anatomy, bad crop"

# Initial prompt; it is reassigned before the active queue_gen call.
prompt = "Close-up portrait of a woman wearing suit posing with black background, rim lighting, octane, unreal"
# Base seed; only referenced by the commented-out queue_gen example.
seed = 383947828373273


# Configuration file handed to ServiceThread.
cfg_file = 'config.yaml'


def random_session(pipe, model):
    """Build a throw-away session record with a random identifier.

    Args:
        pipe: value stored under the session's "pipe" key.
        model: value stored under the session's "model" key.

    Returns:
        A dict with two entries: "session_id" (a random integer in
        [0, 2**32 - 1] rendered as a decimal string) and "session"
        (a dict with the keys "images", "user", "pipe" and "model").
        The shape allows it to be unpacked into a call with ``**``.
    """
    # Named session_id rather than id so the built-in id() is not shadowed.
    session_id = str(random.randint(0, 2**32 - 1))
    session = {
        "images": [],
        "user": session_id,
        "pipe": pipe,
        "model": model,
    }
    return {"session_id": session_id, "session": session}


# Create the worker service from the configuration file and start it.
worker = ServiceThread(cfg_file)
worker.start()

# NOTE: the first value is immediately overwritten; "image2image" is the one in effect.
pipe = "prompt2image"
pipe = "image2image"
# Use the last key of the worker's 'base' models as the model for the sessions.
model = list(worker.models['base'].keys())[-1]

# Number of images requested, and the number reported so far by the callback.
count = 5
c = 0


def on_new_image(*args, **kwargs):
    """Image callback: print what was received and bump the global counter ``c``."""
    global c
    print(args, kwargs)
    print('on new image')
    c = c + 1

def on_finish(*args, **kwargs):
    """Finish callback: print a marker line followed by the received arguments."""
    received = f'{args} {kwargs}'
    print('finish', received, sep='\n')


# worker.queue_gen(session_id=sess_id,
# images=None,
# prompt=prompt, pipe='Prompt2ImPipe',
# nprompt=nprompt,
# image_callback=on_new_image,
# # lpw=False,
# width=1024, height=1024, steps=4,
# guidance_scale=0,
# count=count,
# seeds=[seed + i for i in range(count)],
# )

# Generator seeded with a fixed value; it is passed to the worker with the request.
generator = torch.Generator().manual_seed(92)
# Input image; it is passed as `image` to the inpainting request further down.
init_image = PIL.Image.open('cr.png')
# Session for the image2image request. That request is commented out, and this
# value is replaced before the next queue_gen call.
random_sess = random_session(pipe, model)
#worker.queue_gen(
# gen_dir='/tmp/img1',
# image_callback=on_new_image,
# prompt=prompt, count=count,
# images=[init_image], generator=generator,
# num_inference_steps=50, strength=0.82, guidance_scale=3.5,
# **random_sess)


# Switch to the inpainting pipe and replace the prompt for this request.
pipe = "inpaint"
prompt = "a football player holding a gun, pointing it towards viewer"

# Mask image, passed as `mask` together with init_image.
mask = PIL.Image.open('select.png')
# Fresh temporary session; its 'session_id' and 'session' entries reach
# queue_gen through **random_sess.
random_sess = random_session(pipe, model)
# Queue the inpainting job. on_new_image increments the counter that the wait
# loop at the end of the script polls; on_finish only prints.
# NOTE(review): presumably image_callback fires once per generated image -
# confirm against the worker implementation.
worker.queue_gen(
gen_dir='/tmp/img1_inp',
finish_callback=on_finish,
image_callback=on_new_image,
prompt=prompt, count=count,
image=init_image, mask=mask, generator=generator,
guidance_scale=7,
strength=0.9, steps=30,
**random_sess)



# Poll once per second until the image callback has reported at least `count`
# images, then stop the worker. `<` is used instead of `!=` so the loop still
# terminates if the callback fires more often than expected.
# NOTE(review): if generation fails and fewer images arrive, this still waits
# forever - consider also ending the wait from the finish/error path.
while c < count:
    time.sleep(1)
worker.stop()
2 changes: 1 addition & 1 deletion multigen/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def setup_logger(path='log_file.log'):

ch.addFilter(thread_id_filter)
fh.addFilter(thread_id_filter)
formatter = logging.Formatter('%(asctime)s - %(thread)d - %(levelname)s - %(message)s')
formatter = logging.Formatter('%(asctime)s - %(thread)d - %(levelname)s - %(funcName)20s() - %(message)s')
fh.setFormatter(formatter)
ch.setFormatter(formatter)

Expand Down
12 changes: 6 additions & 6 deletions multigen/pipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def __init__(self, model_id: str,
"""
if device is None:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pipe_passed = pipe is not None
self.pipe = pipe
self._scheduler = None
self._hypernets = []
Expand All @@ -126,8 +125,7 @@ def __init__(self, model_id: str,
if mt != model_type:
raise RuntimeError(f"passed model type {self.model_type} doesn't match actual type {mt}")

if not pipe_passed:
self._initialize_pipe(device, offload_device)
self._initialize_pipe(device, offload_device)
self.lpw = lpw
self._loras = []

Expand Down Expand Up @@ -164,10 +162,11 @@ def _initialize_pipe(self, device, offload_device):
self.pipe.vae.enable_tiling()
# --- the best one and seems to be enough ---
# self.pipe.enable_sequential_cpu_offload()
if offload_device is not None:
self.pipe.enable_sequential_cpu_offload(offload_device)
logging.debug(f'enable_sequential_cpu_offload for pipe dtype {self.pipe.dtype}')
if self.model_type == ModelType.FLUX:
if offload_device is not None:
self.pipe.enable_sequential_cpu_offload(offload_device)
logging.debug(f'enable_sequential_cpu_offload for pipe dtype {self.pipe.dtype}')
pass
else:
try:
import xformers
Expand Down Expand Up @@ -409,6 +408,7 @@ def gen(self, inputs: dict):
generated image
"""
kwargs = self.prepare_inputs(inputs)
logging.debug("Prompt2ImPipe.gen calling pipe")
image = self.pipe(**kwargs).images[0]
return image

Expand Down
29 changes: 20 additions & 9 deletions multigen/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import time
import concurrent
from queue import Empty
import PIL

from .worker_base import ServiceThreadBase
from .prompting import Cfgen
Expand Down Expand Up @@ -62,7 +63,7 @@ def _get_pipeline(self, pipe_class, model_id, model_type, cnet=None):
cls = pipe_class._classflux
if device.type == 'cuda':
offload_device = device.index
device = torch.device('cpu', 0)
device = torch.device('cpu')
else:
cls = pipe_class._class
pipeline = self._loader.load_pipeline(cls, model_id, torch_dtype=torch.bfloat16,
Expand Down Expand Up @@ -110,8 +111,10 @@ def _update(sess, job, gs):
device = None
# keep the job in the queue until complete
try:
sess = data.get('session', None)
session_id = data["session_id"]
sess = self.sessions[session_id]
if sess is None:
sess = self.sessions[session_id]
sess['status'] ='running'
self.logger.info("GENERATING: " + str(data))
if 'start_callback' in data:
Expand All @@ -132,7 +135,9 @@ def _update(sess, job, gs):
raise RuntimeError(f"unexpected model type {mt}")
pipe = self.get_pipeline(pipe_name, model_id, model_type, cnet=data.get('cnet', None))
device = pipe.pipe.device
offload_device = pipe.offload_gpu_id
offload_device = None
if hasattr(pipe, 'offload_gpu_id'):
offload_device = pipe.offload_gpu_id
self.logger.debug(f'running job on {device} offload {offload_device}')
if device.type in ['cuda', 'meta']:
with self._lock:
Expand All @@ -143,26 +148,32 @@ def _update(sess, job, gs):
class_name = str(pipe.__class__)
self.logger.debug(f'got pipeline {class_name}')

images = data['images']
if 'MaskedIm2ImPipe' in class_name:
images = data.get('images', None)
if images and 'MaskedIm2ImPipe' in class_name:
pipe.setup(**data, original_image=str(images[0]),
image_painted=str(images[1]))
elif any([x in class_name for x in ('Im2ImPipe', 'Cond2ImPipe')]):
pipe.setup(**data, fimage=str(images[0]))
elif images and any([x in class_name for x in ('Im2ImPipe', 'Cond2ImPipe')]):
if isinstance(images[0], PIL.Image.Image):
pipe.setup(**data, fimage=None, image=images[0])
else:
pipe.setup(**data, fimage=str(images[0]))
else:
pipe.setup(**data)
# TODO: add negative prompt to parameters
nprompt_default = "jpeg artifacts, blur, distortion, watermark, signature, extra fingers, fewer fingers, lowres, nude, bad hands, duplicate heads, bad anatomy, bad crop"
nprompt = data.get('nprompt', nprompt_default)
seeds = data.get('seeds', None)
self.logger.debug(f"offload_device {pipe.offload_gpu_id}")
gs = GenSession(self.get_image_pathname(data["session_id"], None),
directory = data.get('gen_dir', None)
if directory is None:
directory = self.get_image_pathname(data["session_id"], None)
gs = GenSession(directory,
pipe, Cfgen(data["prompt"], nprompt, seeds=seeds))
gs.gen_sess(add_count = data["count"],
callback = lambda: _update(sess, data, gs))
if 'finish_callback' in data:
data['finish_callback']()
except (RuntimeError, TypeError, NotImplementedError) as e:
except (RuntimeError, TypeError, NotImplementedError, OSError) as e:
self.logger.error("error in generation", exc_info=e)
if hasattr(pipe.pipe, '_offload_gpu_id'):
self.logger.error(f"offload_device {pipe.pipe._offload_gpu_id}")
Expand Down
7 changes: 2 additions & 5 deletions multigen/worker_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import yaml
from pathlib import Path
import logging
from .pipes import Prompt2ImPipe, MaskedIm2ImPipe, Cond2ImPipe, Im2ImPipe
from .pipes import Prompt2ImPipe, MaskedIm2ImPipe, Cond2ImPipe, Im2ImPipe, InpaintingPipe
from .loader import Loader


Expand Down Expand Up @@ -96,10 +96,7 @@ def queue_gen(self, **args):
self.logger.info("REQUESTED FOR QUEUE: " + str(args))
with self._lock:
if args["session_id"] not in self.sessions:
return { "error": "Session is not open" }
# for q in self.queue:
# if q["session_id"] == args["session_id"]:
# return { "error": "The job for this session already exists" }
self.logger.debug(str(args["session_id"]) + ' is not found in open sessions')
a = {**args}
a["count"] = int(a["count"])
if a["count"] <= 0:
Expand Down
28 changes: 2 additions & 26 deletions tests/pipe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def test_loader(self):
# load inpainting pipe
cls = classes[model_type]
pipeline = loader.load_pipeline(cls, model_id, **self.device_args)
inpaint = MaskedIm2ImPipe(model_id, pipe=pipeline)
inpaint = MaskedIm2ImPipe(model_id, pipe=pipeline, **self.device_args)


prompt_classes = self.get_cls_by_type(Prompt2ImPipe)
Expand All @@ -138,7 +138,7 @@ def test_loader(self):
device = torch.device('cpu', 0)
device_args['device'] = device
pipeline = loader.load_pipeline(cls, model_id, **device_args)
prompt2image = Prompt2ImPipe(model_id, pipe=pipeline)
prompt2image = Prompt2ImPipe(model_id, pipe=pipeline, **device_args)
prompt2image.setup(width=512, height=512, scheduler=self.schedulers[0], clip_skip=2, steps=5)
if device.type == 'cuda':
self.assertEqual(inpaint.pipe.unet.conv_out.weight.data_ptr(),
Expand Down Expand Up @@ -323,30 +323,6 @@ def get_model(self):
def test_lpw_turned_off(self):
pass

def est_basic_txt2im(self):
model = self.get_model()
device = torch.device('cpu', 0)
# create pipe
offload = 0 if torch.cuda.is_available() else None
pipe = Prompt2ImPipe(model, pipe=self._pipeline,
model_type=self.model_type(),
device=device, offload_device=offload)
pipe.setup(width=512, height=512, guidance_scale=7, scheduler="FlowMatchEulerDiscreteScheduler", steps=5)
seed = 49045438434843
params = dict(prompt="a cube planet, cube-shaped, space photo, masterpiece",
negative_prompt="spherical",
generator=torch.Generator(device).manual_seed(seed))
image = pipe.gen(params)
image.save("cube_test.png")

# generate with different seed
params['generator'] = torch.Generator(device).manual_seed(seed + 1)
image_ddim = pipe.gen(params)
image_ddim.save("cube_test2_dimm.png")
diff = self.compute_diff(image_ddim, image)
# check that difference is large
self.assertGreater(diff, 1000)


if __name__ == '__main__':
setup_logger('test_pipe.log')
Expand Down

0 comments on commit a56ae2e

Please sign in to comment.