fix cpu offload setup #85

Merged · 5 commits · Oct 24, 2024
108 changes: 108 additions & 0 deletions examples/work2im_temp_session.py
@@ -0,0 +1,108 @@
"""
Example of using Hyper-SD with metafusion


# https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0
requires this model to be in models-sd/SDXL/stable-diffusion-xl-base-1.0/sd_xl_base_1.0.safetensors

and https://huggingface.co/ByteDance/Hyper-SD/blob/main/Hyper-SD15-4steps-lora.safetensors to be in

models-sd/Lora/Hyper-SDXL-4steps-lora.safetensors
"""

import time
import random
import PIL.Image
import torch
from multigen.prompting import Cfgen
from multigen.sessions import GenSession
from multigen.pipes import Prompt2ImPipe

from multigen.worker import ServiceThread

from multigen.log import setup_logger
setup_logger()

nprompt = "jpeg artifacts, blur, distortion, watermark, signature, extra fingers, fewer fingers, lowres, bad hands, duplicate heads, bad anatomy, bad crop"

prompt = "Close-up portrait of a woman wearing suit posing with black background, rim lighting, octane, unreal"
seed = 383947828373273


cfg_file = 'config.yaml'


def random_session(pipe, model):
id = str(random.randint(0, 1024*1024*1024*4-1))
session = dict()
session["images"] = []
session["user"] = id
session["pipe"] = pipe
session["model"] = model
return { "session_id": id, 'session': session }


worker = ServiceThread(cfg_file)
worker.start()

pipe = "prompt2image"
pipe = "image2image"
model = list(worker.models['base'].keys())[-1]

count = 5
c = 0
def on_new_image(*args, **kwargs):
print(args, kwargs)
print('on new image')
global c
c += 1

def on_finish(*args, **kwargs):
print('finish')
print(args, kwargs)


# worker.queue_gen(session_id=sess_id,
# images=None,
# prompt=prompt, pipe='Prompt2ImPipe',
# nprompt=nprompt,
# image_callback=on_new_image,
# # lpw=False,
# width=1024, height=1024, steps=4,
# guidance_scale=0,
# count=count,
# seeds=[seed + i for i in range(count)],
# )

generator = torch.Generator().manual_seed(92)
init_image = PIL.Image.open('cr.png')
random_sess = random_session(pipe, model)
#worker.queue_gen(
# gen_dir='/tmp/img1',
# image_callback=on_new_image,
# prompt=prompt, count=count,
# images=[init_image], generator=generator,
# num_inference_steps=50, strength=0.82, guidance_scale=3.5,
# **random_sess)


pipe = "inpaint"
prompt = "a football player holding a gun, pointing it towards viewer"

mask = PIL.Image.open('select.png')
random_sess = random_session(pipe, model)
worker.queue_gen(
gen_dir='/tmp/img1_inp',
finish_callback=on_finish,
image_callback=on_new_image,
prompt=prompt, count=count,
image=init_image, mask=mask, generator=generator,
guidance_scale=7,
strength=0.9, steps=30,
**random_sess)



while count != c:
time.sleep(1)
worker.stop()
2 changes: 1 addition & 1 deletion multigen/log.py
@@ -18,7 +18,7 @@ def setup_logger(path='log_file.log'):

ch.addFilter(thread_id_filter)
fh.addFilter(thread_id_filter)
-formatter = logging.Formatter('%(asctime)s - %(thread)d - %(levelname)s - %(message)s')
+formatter = logging.Formatter('%(asctime)s - %(thread)d - %(levelname)s - %(funcName)20s() - %(message)s')
fh.setFormatter(formatter)
ch.setFormatter(formatter)

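
The formatter change appends the calling function's name to every record. A minimal
sketch of what the new format produces, assuming a plain stream handler (the gen()
function here is only an illustration, not part of the PR):

    import logging

    # same format string as the updated setup_logger()
    formatter = logging.Formatter(
        '%(asctime)s - %(thread)d - %(levelname)s - %(funcName)20s() - %(message)s')

    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logger = logging.getLogger('demo')
    logger.addHandler(handler)
    logger.setLevel(logging.DEBUG)

    def gen():
        # the record now carries the caller's name, e.g.
        # "2024-10-24 ... - 140231 - DEBUG -                  gen() - calling pipe"
        logger.debug('calling pipe')

    gen()
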
12 changes: 6 additions & 6 deletions multigen/pipes.py
@@ -102,7 +102,6 @@ def __init__(self, model_id: str,
"""
if device is None:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-pipe_passed = pipe is not None
self.pipe = pipe
self._scheduler = None
self._hypernets = []
@@ -126,8 +125,7 @@ def __init__(self, model_id: str,
if mt != model_type:
raise RuntimeError(f"passed model type {self.model_type} doesn't match actual type {mt}")

-if not pipe_passed:
-    self._initialize_pipe(device, offload_device)
+self._initialize_pipe(device, offload_device)
self.lpw = lpw
self._loras = []

@@ -164,10 +162,11 @@ def _initialize_pipe(self, device, offload_device):
self.pipe.vae.enable_tiling()
# --- the best one and seems to be enough ---
# self.pipe.enable_sequential_cpu_offload()
+if offload_device is not None:
+    self.pipe.enable_sequential_cpu_offload(offload_device)
+    logging.debug(f'enable_sequential_cpu_offload for pipe dtype {self.pipe.dtype}')
if self.model_type == ModelType.FLUX:
-    if offload_device is not None:
-        self.pipe.enable_sequential_cpu_offload(offload_device)
-        logging.debug(f'enable_sequential_cpu_offload for pipe dtype {self.pipe.dtype}')
+    pass
else:
try:
import xformers
@@ -409,6 +408,7 @@ def gen(self, inputs: dict):
generated image
"""
kwargs = self.prepare_inputs(inputs)
logging.debug("Prompt2ImPipe.gen calling pipe")
image = self.pipe(**kwargs).images[0]
return image

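
The net effect in pipes.py: sequential CPU offload is now enabled for every model
type when an offload device is given (previously only the FLUX branch did it), and
_initialize_pipe() also runs for externally constructed pipes. A minimal sketch of
the underlying diffusers call outside metafusion (the model path and GPU index are
placeholders, not part of this PR):

    import torch
    from diffusers import StableDiffusionXLPipeline

    pipe = StableDiffusionXLPipeline.from_pretrained(
        'stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16)

    offload_device = 0  # GPU index that receives submodules on demand
    if offload_device is not None:
        # weights stay on the CPU; each submodule is streamed to the GPU only
        # while it executes, trading inference speed for minimal VRAM use
        pipe.enable_sequential_cpu_offload(gpu_id=offload_device)
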
29 changes: 20 additions & 9 deletions multigen/worker.py
@@ -3,6 +3,7 @@
import time
import concurrent
from queue import Empty
+import PIL

from .worker_base import ServiceThreadBase
from .prompting import Cfgen
@@ -62,7 +63,7 @@ def _get_pipeline(self, pipe_class, model_id, model_type, cnet=None):
cls = pipe_class._classflux
if device.type == 'cuda':
offload_device = device.index
-device = torch.device('cpu', 0)
+device = torch.device('cpu')
else:
cls = pipe_class._class
pipeline = self._loader.load_pipeline(cls, model_id, torch_dtype=torch.bfloat16,
@@ -110,8 +111,10 @@ def _update(sess, job, gs):
device = None
# keep the job in the queue until complete
try:
+sess = data.get('session', None)
session_id = data["session_id"]
-sess = self.sessions[session_id]
+if sess is None:
+    sess = self.sessions[session_id]
sess['status'] ='running'
self.logger.info("GENERATING: " + str(data))
if 'start_callback' in data:
@@ -132,7 +135,9 @@ def _update(sess, job, gs):
raise RuntimeError(f"unexpected model type {mt}")
pipe = self.get_pipeline(pipe_name, model_id, model_type, cnet=data.get('cnet', None))
device = pipe.pipe.device
-offload_device = pipe.offload_gpu_id
+offload_device = None
+if hasattr(pipe, 'offload_gpu_id'):
+    offload_device = pipe.offload_gpu_id
self.logger.debug(f'running job on {device} offload {offload_device}')
if device.type in ['cuda', 'meta']:
with self._lock:
@@ -143,26 +148,32 @@ def _update(sess, job, gs):
class_name = str(pipe.__class__)
self.logger.debug(f'got pipeline {class_name}')

-images = data['images']
-if 'MaskedIm2ImPipe' in class_name:
+images = data.get('images', None)
+if images and 'MaskedIm2ImPipe' in class_name:
    pipe.setup(**data, original_image=str(images[0]),
               image_painted=str(images[1]))
-elif any([x in class_name for x in ('Im2ImPipe', 'Cond2ImPipe')]):
-    pipe.setup(**data, fimage=str(images[0]))
+elif images and any([x in class_name for x in ('Im2ImPipe', 'Cond2ImPipe')]):
+    if isinstance(images[0], PIL.Image.Image):
+        pipe.setup(**data, fimage=None, image=images[0])
+    else:
+        pipe.setup(**data, fimage=str(images[0]))
else:
pipe.setup(**data)
# TODO: add negative prompt to parameters
nprompt_default = "jpeg artifacts, blur, distortion, watermark, signature, extra fingers, fewer fingers, lowres, nude, bad hands, duplicate heads, bad anatomy, bad crop"
nprompt = data.get('nprompt', nprompt_default)
seeds = data.get('seeds', None)
self.logger.debug(f"offload_device {pipe.offload_gpu_id}")
gs = GenSession(self.get_image_pathname(data["session_id"], None),
directory = data.get('gen_dir', None)
if directory is None:
directory = self.get_image_pathname(data["session_id"], None)
gs = GenSession(directory,
pipe, Cfgen(data["prompt"], nprompt, seeds=seeds))
gs.gen_sess(add_count = data["count"],
callback = lambda: _update(sess, data, gs))
+if 'finish_callback' in data:
+    data['finish_callback']()
-except (RuntimeError, TypeError, NotImplementedError) as e:
+except (RuntimeError, TypeError, NotImplementedError, OSError) as e:
self.logger.error("error in generation", exc_info=e)
if hasattr(pipe.pipe, '_offload_gpu_id'):
self.logger.error(f"offload_device {pipe.pipe._offload_gpu_id}")
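
With these worker changes, a job may carry its session inline (the 'session' key in
the request data), write to a caller-chosen 'gen_dir', and pass images either as
paths or as in-memory PIL images. A sketch of the two image forms, reusing names
from the example file above (random_session(), worker, prompt, model are assumed
from that file):

    import PIL.Image

    init_image = PIL.Image.open('cr.png')

    # in-memory image: forwarded as-is, reaches pipe.setup(image=...)
    worker.queue_gen(images=[init_image], prompt=prompt, count=1,
                     **random_session("image2image", model))

    # path-like entry: stringified, reaches pipe.setup(fimage='cr.png')
    worker.queue_gen(images=['cr.png'], prompt=prompt, count=1,
                     **random_session("image2image", model))
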
7 changes: 2 additions & 5 deletions multigen/worker_base.py
@@ -7,7 +7,7 @@
import yaml
from pathlib import Path
import logging
-from .pipes import Prompt2ImPipe, MaskedIm2ImPipe, Cond2ImPipe, Im2ImPipe
+from .pipes import Prompt2ImPipe, MaskedIm2ImPipe, Cond2ImPipe, Im2ImPipe, InpaintingPipe
from .loader import Loader


@@ -96,10 +96,7 @@ def queue_gen(self, **args):
self.logger.info("REQUESTED FOR QUEUE: " + str(args))
with self._lock:
if args["session_id"] not in self.sessions:
Collaborator: I'd keep these checks at least for logs. Or are they not relevant anymore at all?

Collaborator (author): The old API would work just fine. Relevance depends on the clients of our API.
return { "error": "Session is not open" }
# for q in self.queue:
# if q["session_id"] == args["session_id"]:
# return { "error": "The job for this session already exists" }
self.logger.debug(str(args["session_id"]) + ' is not found in open sessions')
a = {**args}
a["count"] = int(a["count"])
if a["count"] <= 0:
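
The guard relaxation above is what the new example file relies on: it never
registers a session with the service, but passes one inline (the 'session' key
consumed by _update() in worker.py above), so queue_gen() must tolerate session
ids it has never seen. A sketch of that flow, using names from the example file:

    # random_session() returns {'session_id': ..., 'session': {...}}
    sess = random_session("inpaint", model)
    # the id is unknown to worker.sessions: formerly an error return,
    # now only a debug log; _update() falls back to data['session']
    worker.queue_gen(prompt=prompt, count=count, **sess)
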
28 changes: 2 additions & 26 deletions tests/pipe_test.py
@@ -123,7 +123,7 @@ def test_loader(self):
# load inpainting pipe
cls = classes[model_type]
pipeline = loader.load_pipeline(cls, model_id, **self.device_args)
-inpaint = MaskedIm2ImPipe(model_id, pipe=pipeline)
+inpaint = MaskedIm2ImPipe(model_id, pipe=pipeline, **self.device_args)


prompt_classes = self.get_cls_by_type(Prompt2ImPipe)
@@ -138,7 +138,7 @@
device = torch.device('cpu', 0)
device_args['device'] = device
pipeline = loader.load_pipeline(cls, model_id, **device_args)
-prompt2image = Prompt2ImPipe(model_id, pipe=pipeline)
+prompt2image = Prompt2ImPipe(model_id, pipe=pipeline, **device_args)
prompt2image.setup(width=512, height=512, scheduler=self.schedulers[0], clip_skip=2, steps=5)
if device.type == 'cuda':
self.assertEqual(inpaint.pipe.unet.conv_out.weight.data_ptr(),
@@ -323,30 +323,6 @@ def get_model(self):
def test_lpw_turned_off(self):
pass

-def est_basic_txt2im(self):
-    model = self.get_model()
-    device = torch.device('cpu', 0)
-    # create pipe
-    offload = 0 if torch.cuda.is_available() else None
-    pipe = Prompt2ImPipe(model, pipe=self._pipeline,
-                         model_type=self.model_type(),
-                         device=device, offload_device=offload)
-    pipe.setup(width=512, height=512, guidance_scale=7, scheduler="FlowMatchEulerDiscreteScheduler", steps=5)
-    seed = 49045438434843
-    params = dict(prompt="a cube planet, cube-shaped, space photo, masterpiece",
-                  negative_prompt="spherical",
-                  generator=torch.Generator(device).manual_seed(seed))
-    image = pipe.gen(params)
-    image.save("cube_test.png")
-
-    # generate with different seed
-    params['generator'] = torch.Generator(device).manual_seed(seed + 1)
-    image_ddim = pipe.gen(params)
-    image_ddim.save("cube_test2_dimm.png")
-    diff = self.compute_diff(image_ddim, image)
-    # check that difference is large
-    self.assertGreater(diff, 1000)


if __name__ == '__main__':
setup_logger('test_pipe.log')