Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add stable-diffusion example using panel #254

Open
wants to merge 63 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
63 commits
Select commit Hold shift + click to select a range
c951be6
Add stable-diffusion example with panel widgets
sandhujasmine Jan 30, 2023
aa3f94d
rename to change - to _
sandhujasmine Feb 2, 2023
c81806e
delete .projectignore since its not needed
sandhujasmine Feb 2, 2023
1076779
fix name; remove deployment and unused parts
sandhujasmine Feb 2, 2023
a91564c
Merge branch 'main' into stable-diffusion
sandhujasmine Feb 2, 2023
90d37bc
Add min versions based on solved spec
sandhujasmine Feb 2, 2023
8c3daf1
Add lock file; define min package versions in spec
sandhujasmine Feb 8, 2023
881108c
Use cuda if available. Should allow CI to run notebook
sandhujasmine Feb 8, 2023
127a676
Updated as panel Viewer class (thanks @philippjfr!)
sandhujasmine Feb 10, 2023
9815d86
Removed metadata per review comments.
sandhujasmine Feb 10, 2023
c44692c
Quick test to see if updating conda fixes the libmamba CI issue
sandhujasmine Feb 13, 2023
2765d63
Merge branch 'stable-diffusion' of github.com:pyviz-topics/examples i…
sandhujasmine Feb 13, 2023
ea2ace6
add check for mps; fix app so copy/paste URL works to regenerate image
sandhujasmine Feb 15, 2023
dbe639a
Merge branch 'stable-diffusion' of github.com:pyviz-topics/examples i…
sandhujasmine Feb 15, 2023
aecb8e5
Revert "Quick test to see if updating conda fixes the libmamba CI issue"
sandhujasmine Feb 15, 2023
4574430
Merge branch 'main' into stable-diffusion
sandhujasmine Feb 15, 2023
71c1a19
Fix torch API calls to check MPS availability
sandhujasmine Feb 16, 2023
5c99f00
Make computation more efficient
sandhujasmine Feb 19, 2023
01e147d
remove pins for now so conda solve works
sandhujasmine Feb 19, 2023
2428a39
generate on prompt enter as well
sandhujasmine Feb 20, 2023
891b997
make gallery a grid; add diffusers logo
sandhujasmine Feb 27, 2023
7d30c94
Define env_spec for stable-diffusion-m1 so it works on osx & linux
sandhujasmine Feb 27, 2023
e5cd512
cropped diffusers logo
sandhujasmine Feb 28, 2023
327d761
lock file incorrect for osx; remove it.
sandhujasmine Feb 28, 2023
3bd3608
Updated the documentation
sandhujasmine Feb 28, 2023
7d1305d
remove future improvements section
sandhujasmine Feb 28, 2023
78bcb9a
remove empty cell and add some more documentation
sandhujasmine Feb 28, 2023
5dbcaab
Clean up
sandhujasmine Feb 28, 2023
bbf2bb1
Fixed so it works on linux with cpu only
sandhujasmine Mar 1, 2023
54ed516
yet another lock file which seems to work on both (linux-64, osx)
sandhujasmine Mar 1, 2023
b22bd84
limit history of thumbnails to 15
sandhujasmine Mar 2, 2023
c355008
Load old image from memory; invoke callback on prompt or button; clean
sandhujasmine Mar 6, 2023
94457c0
Add label to random seed
sandhujasmine Mar 7, 2023
73c0135
Add commands to run on M1
sandhujasmine Mar 7, 2023
a9965e2
updated docs
sandhujasmine Mar 8, 2023
fff4a7a
Add description for the commands
sandhujasmine Mar 9, 2023
883ba6c
Update thumbnail
sandhujasmine Mar 9, 2023
0b6f286
Merge branch 'main' into stable-diffusion
maximlt Mar 12, 2023
565d0b0
Updated and cleaned up
jbednar Apr 26, 2023
7684f01
Merge branch 'main' into stable-diffusion
maximlt Apr 30, 2023
a53bf2b
declare no data ingestion
maximlt Apr 30, 2023
9782015
resize image
maximlt Apr 30, 2023
eea71a6
clean up project file
maximlt Apr 30, 2023
8b95eb2
linting fixes
maximlt Apr 30, 2023
9df5d24
Merge branch 'main' into stable-diffusion
maximlt Apr 30, 2023
e5ad566
temporary skip tests
maximlt Apr 30, 2023
7495758
allow skip_test
maximlt Apr 30, 2023
261ea35
run on macos
maximlt Apr 30, 2023
a825edb
point to the general environment file
maximlt Apr 30, 2023
19f3ed4
Revert "point to the general environment file"
maximlt May 1, 2023
5a4f3e0
Revert "run on macos"
maximlt May 1, 2023
ccfe2bc
Revert "allow skip_test"
maximlt May 1, 2023
edf9a33
Revert "temporary skip tests"
maximlt May 1, 2023
4944738
Merge branch 'main' into stable-diffusion
maximlt May 1, 2023
f681c77
test and build on macos-latest
maximlt May 1, 2023
f1214d6
[relaunch build]
maximlt May 1, 2023
83a18ec
[debug]
maximlt May 2, 2023
5db5619
checkout original branch
maximlt May 2, 2023
2697e32
undebug
maximlt May 2, 2023
c1bf92b
Merge branch 'main' into stable-diffusion
maximlt May 2, 2023
d0cb627
Fix link and add a link to huggingface article on memory
sandhujasmine May 10, 2023
6faab85
Merge branch 'main' into stable-diffusion
maximlt May 16, 2023
6972c11
Merge branch 'main' into stable-diffusion
maximlt Jun 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 51 additions & 48 deletions stable_diffusion/stable_diffusion.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,20 @@
"\n",
"@contextmanager\n",
"def exec_time(description=\"Task\"):\n",
" \"\"\"Context manager to measure execution time and print it to the console\"\"\" \n",
" \"\"\"Context manager to measure execution time and print it to the console\"\"\"\n",
" st = time.perf_counter()\n",
" yield \n",
" yield\n",
" print(f\"{description}: {time.perf_counter() - st:.2f} sec\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Invoking Stable Diffusion on a prompt\n",
"\n",
"The `init_model` function below will first look in the default cache location used by huggingface to find downloaded pretrained models. If these haven't been downloaded yet, it will first download the models. On subsequent restarts of the app, it will load the models from the local disk cache. \n",
"The `init_model` function below will first look in the default cache location used by huggingface to find downloaded pretrained models. If these haven't been downloaded yet, it will first download the models. On subsequent restarts of the app, it will load the models from the local disk cache.\n",
"\n",
"<p>\n",
"<details><summary><u>(Optional: how to download models manually)</u></summary>\n",
Expand All @@ -94,7 +95,7 @@
"The initial page load takes an extra ~10 sec or so (on a Quadro RTX 8000) and allocates the GPU memory required to load the pipeline in memory. Subsequent visitors get this pipeline from panel's cache. The memory overhead per visitor is then the amount needed to generate the image text prompt.\n",
"\n",
"<details><summary><p><br><u>(Optional: performance details)</u></summary>\n",
" \n",
"\n",
"Sample output from `nvidia-smi` with memory usage information, running on a machine with Quadro RTX 8000 GPUs, after both models load:\n",
"\n",
"<pre>\n",
Expand Down Expand Up @@ -138,31 +139,33 @@
"\n",
"def init_model(model, cuda, mps, local_files_only=True):\n",
" print(f\"Init model: {model}\")\n",
" pipe = StableDiffusionPipeline.from_pretrained(model,\n",
" pipe = StableDiffusionPipeline.from_pretrained(\n",
" model,\n",
" torch_dtype=torch.float16 if cuda or mps else None,\n",
" local_files_only=local_files_only) \n",
" local_files_only=local_files_only\n",
" )\n",
"\n",
" # let torch choose the GPU if more than 1 is available\n",
" if cuda:\n",
" pipe.to(f\"cuda\")\n",
" pipe.to(\"cuda\")\n",
" elif mps:\n",
" pipe.to(f\"mps\")\n",
" pipe.to(\"mps\")\n",
" pipe.enable_attention_slicing()\n",
" return pipe \n",
" return pipe\n",
"\n",
"\n",
"if 'pipelines' in pn.state.cache:\n",
" print(f\"load from cache\")\n",
" print(\"load from cache\")\n",
" pipelines = pn.state.cache['pipelines']\n",
" pseudo_rand_gen = pn.state.cache['pseudo_rand_gen']\n",
"else:\n",
" cuda = torch.cuda.is_available()\n",
" mps = torch.backends.mps.is_available()\n",
" device = 'cuda' if cuda else 'cpu'\n",
"\n",
" models = ['runwayml/stable-diffusion-v1-5', \n",
" models = ['runwayml/stable-diffusion-v1-5',\n",
" 'CompVis/stable-diffusion-v1-4']\n",
" \n",
"\n",
" pseudo_rand_gen = torch.Generator(device=device)\n",
" with exec_time(\"Load models\"):\n",
" pipelines = dict()\n",
Expand All @@ -172,10 +175,10 @@
" pipelines[m] = init_model(m, cuda, mps)\n",
" except OSError:\n",
" pipelines[m] = init_model(m, cuda, mps, local_files_only=False)\n",
" \n",
"\n",
" pn.state.cache['pipelines'] = pipelines\n",
" pn.state.cache['pseudo_rand_gen'] = pseudo_rand_gen\n",
" print(f\"Save to cache\")\n",
" print(\"Save to cache\")\n",
"\n",
"\n",
"default_model = next(iter(pipelines))"
Expand Down Expand Up @@ -218,45 +221,46 @@
"class StableDiffusion(param.Parameterized):\n",
" prompt = param.String(doc=\"\"\"\n",
" Text describing the image you wish to generate\"\"\")\n",
" \n",
"\n",
" negative_prompt = param.String(doc=\"\"\"\n",
" Text describing what _not_ to include in the image (for refining results)\"\"\")\n",
" \n",
"\n",
" model = param.Selector(objects=list(pipelines), default=default_model, doc=\"\"\"\n",
" A pre-trained model to be used for inference\"\"\")\n",
" \n",
"\n",
" _size_range = tuple(448 + i*2**6 for i in range(10))\n",
" width = param.Selector(_size_range, default=_size_range[1], doc=\"\"\"\n",
" Width (in pixels) of the images to generate\"\"\")\n",
" \n",
"\n",
" height = param.Selector(_size_range, default=_size_range[1], doc=\"\"\"\n",
" Height (in pixels) of the images to generate\"\"\")\n",
" \n",
"\n",
" guidance_scale = param.Number(bounds=(5, 10), softbounds=(7, 8.5), step=0.1, default=7.5, doc=\"\"\"\n",
" How closely the model should try to match the prompt, at the \n",
" How closely the model should try to match the prompt, at the\n",
" potential expense of image quality or diversity.\n",
" Also known as CFG (Classifier-free guidance scale).\"\"\")\n",
" \n",
"\n",
" num_steps = param.Integer(label='# of steps', bounds=(10, 75), default=30, doc=\"\"\"\n",
" How many denoising steps to take. \n",
" How many denoising steps to take.\n",
" More steps takes longer but gives a more-refined image.\"\"\")\n",
"\n",
" seed = param.Integer(label='Random seed', default=1,\n",
" bounds=random_int_range, step=10, precedence=1, doc=\"\"\"\n",
" Seed controlling the noise values generated.\"\"\")\n",
" \n",
"\n",
" generate = param.Event(precedence=1)\n",
" \n",
" param.depends(\"generate\")\n",
"\n",
" @param.depends(\"generate\")\n",
" def __call__(self, **params):\n",
" p = param.ParamOverrides(self, params)\n",
" pipe = pipelines[p.model]\n",
" \n",
"\n",
" res = pipe(num_inference_steps=p.num_steps, generator=pseudo_rand_gen.manual_seed(p.seed),\n",
" **{k:p[k] for k in ['prompt', 'negative_prompt', 'guidance_scale', 'height', 'width']})\n",
"\n",
" return res.images[0]\n",
" \n",
"\n",
"\n",
"sd = StableDiffusion()"
]
},
Expand Down Expand Up @@ -331,7 +335,7 @@
"\n",
"class Gallery(ListLike, ReactiveHTML):\n",
" \"\"\"Collection of thumbnails that, when selected, restore the associated image and its parameters\"\"\"\n",
" \n",
"\n",
" objects = param.List(item_type=Viewable)\n",
" current = param.Integer(default=None)\n",
" margin = param.Integer(0)\n",
Expand All @@ -343,7 +347,7 @@
" {% endfor %}\n",
" </div>\n",
" \"\"\"\n",
" \n",
"\n",
" _scripts = {\n",
" 'click': \"\"\"\n",
" const id = event.target.parentNode.parentNode.parentNode.id;\n",
Expand Down Expand Up @@ -372,13 +376,13 @@
" model = param.Parameter(StableDiffusion())\n",
" gallery = param.ClassSelector(class_=Gallery, default=Gallery(min_height=100), precedence=-1)\n",
" generate_image = param.Event(precedence=1)\n",
" \n",
"\n",
" def __init__(self, **params):\n",
" self.history = deque(maxlen=15)\n",
" super().__init__(**params)\n",
" self.gallery.param.watch(self._restore_history, 'current')\n",
" self._restore = False\n",
" self._image_container = pn.pane.PNG(style={'border': '1px solid black'}, \n",
" self._image_container = pn.pane.PNG(style={'border': '1px solid black'},\n",
" height=self.model.height,\n",
" width=self.model.width)\n",
" # ensure seed always starts out being set\n",
Expand All @@ -395,14 +399,14 @@
" try:\n",
" setattr(self, attr, value)\n",
" yield\n",
" setattr(self, attr, not(value))\n",
" setattr(self, attr, not value)\n",
" except Exception as ex:\n",
" setattr(self, attr, init_state)\n",
" raise ex\n",
"\n",
" def _update_query_params(self):\n",
" \"\"\"\n",
" Remove all params first since update_query will only update the non-default values. \n",
" Remove all params first since update_query will only update the non-default values.\n",
" If the current URL has non-default values, those will be incorrect unless it is first cleared\n",
" \"\"\"\n",
" pn.state.location.search = ''\n",
Expand Down Expand Up @@ -432,18 +436,18 @@
" self._update_query_params()\n",
" # Also update the seed so `generate_image` doesn't recreate same image\n",
" self.model.seed = random.randint(*self.model.param.seed.bounds)\n",
" \n",
"\n",
" @property\n",
" def _state(self):\n",
" return {k: v for k, v in self.model.param.values().items() if k != 'name'}\n",
" \n",
"\n",
" @property\n",
" def _url_params(self):\n",
" # only capture state that deviates from default\n",
" state = {key: getattr(self.model, key) for key, val in self.model.param.defaults().items()\n",
" if key != 'name' and getattr(self.model, key) != val}\n",
" return state\n",
" \n",
"\n",
" def _on_load(self):\n",
" if pn.state.location and pn.state.location.query_params:\n",
" self.model.param.update(pn.state.location.query_params)\n",
Expand All @@ -463,12 +467,11 @@
"\n",
" with exec_time(f\"Generate {self.model.prompt}\"):\n",
" image = self.model()\n",
" image_seed = self.model.seed\n",
"\n",
" if len(self.gallery) == self.history.maxlen:\n",
" # Oldest element from history will be dropped\n",
" self.gallery.remove(self.gallery[0])\n",
" \n",
"\n",
" self.gallery.append(pn.pane.PNG(image.resize((100, 100))))\n",
" # store full state in history\n",
" self.history.append((self._state, image))\n",
Expand All @@ -481,27 +484,27 @@
" def _sidebar_widgets(self):\n",
" return pn.Param(self.model.param, widgets = {\n",
" 'height': pn.widgets.DiscreteSlider,\n",
" 'width' : pn.widgets.DiscreteSlider,\n",
" 'width': pn.widgets.DiscreteSlider,\n",
" 'guidance_scale': {'formatter': PrintfTickFormatter(format='%.1f')},\n",
" 'seed': pn.widgets.IntInput,\n",
" 'prompt': {'visible': False},\n",
" 'negative_prompt': {'visible': False},\n",
" 'generate': {'visible': False}})\n",
"\n",
" def _main_panel(self):\n",
" return pn.Column(pn.Row(pn.Column(self.model.param.prompt, self.model.param.negative_prompt, \n",
" return pn.Column(pn.Row(pn.Column(self.model.param.prompt, self.model.param.negative_prompt,\n",
" sizing_mode='stretch_width'),\n",
" pn.Param(self.param.generate_image, \n",
" widgets={'generate_image': {'button_type': 'success', \n",
" pn.Param(self.param.generate_image,\n",
" widgets={'generate_image': {'button_type': 'success',\n",
" 'height': 110, 'width': 30}})),\n",
" pn.Row(pn.panel(self._image_container, loading_indicator=True), self.gallery))\n",
" \n",
"\n",
" def __panel__(self):\n",
" return pn.Row(\n",
" pn.Column(self._sidebar_widgets()),\n",
" pn.Column(self._main_panel(), sizing_mode='stretch_width'))\n",
"\n",
" \n",
"\n",
"sdui = ModelUI(name='Stable Diffusion with Panel UI')\n",
"\n",
"sdui"
Expand All @@ -521,16 +524,16 @@
"outputs": [],
"source": [
"logo_pn = \"\"\"<a href=\"http://panel.pyviz.org\">\n",
" <img src=\"https://panel.pyviz.org/_static/logo_stacked.png\" \n",
" <img src=\"https://panel.pyviz.org/_static/logo_stacked.png\"\n",
" width=108 height=91 align=\"left\" margin=10px>\"\"\"\n",
"\n",
"logo_diffusers = \"\"\"<a href=\"https://huggingface.co/docs/diffusers/index\">\n",
" <img src=\"./thumbnails/diffusers_logo.png\" \n",
" <img src=\"./thumbnails/diffusers_logo.png\"\n",
" width=198 height=102 align=\"left\" margin=10px>\"\"\"\n",
"\n",
"desc = \"\"\"\n",
" The <a href=\"http://panel.pyviz.org\">Panel</a> library from \n",
" <a href=\"https://holoviz.org/\">HoloViz</a> \n",
" The <a href=\"http://panel.pyviz.org\">Panel</a> library from\n",
" <a href=\"https://holoviz.org/\">HoloViz</a>\n",
" lets you make widget-controlled apps. This Panel app lets you use the\n",
" <a href=\"https://huggingface.co/docs/diffusers/index\">diffusers</a> library to\n",
" generate images from pretrained diffusion models.\"\"\"\n",
Expand Down