Skip to content

Commit b3978ec

Browse files
committed
feat(runner): add diffusers pull script and upgrade to py3.11
1 parent 08830e1 commit b3978ec

File tree

9 files changed

+67
-242
lines changed

9 files changed

+67
-242
lines changed

.dockerignore

+1
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ node_modules
33
*.sw*
44
venv
55
vendor
6+
**/.venv/*

Dockerfile.runner

+18-10
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,20 @@
11
#syntax=docker/dockerfile:1.4
22

3-
ARG TAG=2024-11-21a-empty
3+
ARG TAG=main-small
4+
ARG UV_VERSION="0.5.4"
45

5-
FROM ghcr.io/astral-sh/uv:0.5.4 as uv
6+
FROM ghcr.io/astral-sh/uv:${UV_VERSION} AS uv
7+
FROM ghcr.io/astral-sh/uv:${UV_VERSION}-bookworm-slim AS diffusers-build-env
8+
ENV UV_COMPILE_BYTECODE=1 UV_LINK_MODE=copy UV_PYTHON_INSTALL_DIR=/workspace/helix/runner/helix-diffusers/.python
9+
WORKDIR /workspace/helix/runner/helix-diffusers
10+
RUN --mount=type=cache,target=/root/.cache/uv \
11+
--mount=type=bind,source=runner/helix-diffusers/.python-version,target=.python-version \
12+
--mount=type=bind,source=runner/helix-diffusers/uv.lock,target=uv.lock \
13+
--mount=type=bind,source=runner/helix-diffusers/pyproject.toml,target=pyproject.toml \
14+
uv sync --frozen --no-install-project --no-dev
15+
ADD runner/helix-diffusers /workspace/helix/runner/helix-diffusers
16+
RUN --mount=type=cache,target=/root/.cache/uv \
17+
uv sync --frozen --no-dev
618

719
### BUILD
820

@@ -47,14 +59,10 @@ WORKDIR /workspace/helix
4759
# Copy runner directory from the repo
4860
COPY runner ./runner
4961

50-
# We need to set this environment variable so that uv knows where
51-
# the virtual environment is to install packages
52-
ENV UV_PROJECT_ENVIRONMENT=/workspace/helix/runner/helix-diffusers/venv
53-
54-
# Install the packages with uv using --mount=type=cache to cache the downloaded packages
55-
RUN --mount=type=cache,target=/root/.cache/uv \
56-
--mount=from=uv,source=/uv,target=/usr/bin/uv \
57-
cd /workspace/helix/runner/helix-diffusers && uv sync --no-dev
62+
# Copy the diffusers build environment including Python
63+
COPY --from=ghcr.io/astral-sh/uv:0.5.4 /uv /bin/uv
64+
COPY --from=diffusers-build-env /workspace/helix/runner/helix-diffusers /workspace/helix/runner/helix-diffusers
65+
ENV PATH="/workspace/helix/runner/helix-diffusers/.venv/bin:$PATH"
5866

5967
# Copy the cog wrapper. cog and cog-sdxl are installed in the base image; this is just the cog server
6068
COPY cog/helix_cog_wrapper.py /workspace/cog-sdxl/helix_cog_wrapper.py

api/pkg/model/models.go

+6-5
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ func ProcessModelName(
122122
}
123123
}
124124
case types.SessionTypeImage:
125-
return Model_Diffusers_SD35, nil
125+
return Model_Diffusers_SDTurbo, nil
126126
}
127127

128128
// shouldn't get here
@@ -157,6 +157,7 @@ const (
157157
Model_Axolotl_Mistral7b string = "mistralai/Mistral-7B-Instruct-v0.1"
158158
Model_Cog_SDXL string = "stabilityai/stable-diffusion-xl-base-1.0"
159159
Model_Diffusers_SD35 string = "stabilityai/stable-diffusion-3.5-medium"
160+
Model_Diffusers_SDTurbo string = "stabilityai/sd-turbo"
160161

161162
// We only need constants for _some_ ollama models that are hardcoded in
162163
// various places (backward compat). Other ones can be added dynamically now.
@@ -170,10 +171,10 @@ const (
170171
func GetDefaultDiffusersModels() ([]*DiffusersGenericImage, error) {
171172
return []*DiffusersGenericImage{
172173
{
173-
Id: Model_Diffusers_SD35,
174-
Name: "Stable Diffusion 3.5 Medium",
175-
Memory: GB * 21,
176-
Description: "Medium model, from Stability AI",
174+
Id: Model_Diffusers_SDTurbo,
175+
Name: "Stable Diffusion Turbo",
176+
Memory: GB * 5,
177+
Description: "Turbo model, from Stability AI",
177178
Hide: false,
178179
},
179180
}, nil

api/pkg/runner/diffusers_model_instance.go

+3-8
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ func (i *DiffusersModelInstance) Start(ctx context.Context) error {
307307
if i.filter.Mode == types.SessionModeInference {
308308
cmd = exec.CommandContext(
309309
ctx,
310-
"bash", "/workspace/helix/runner/venv_command.sh",
310+
"uv", "run",
311311
"uvicorn", "main:app",
312312
"--host", "0.0.0.0",
313313
"--port", strconv.Itoa(i.port),
@@ -321,14 +321,9 @@ func (i *DiffusersModelInstance) Start(ctx context.Context) error {
321321
// Set the working directory to the runner dir (which makes relative path stuff easier)
322322
cmd.Dir = "/workspace/helix/runner/helix-diffusers"
323323

324-
// Inherit all the parent environment variables
325324
cmd.Env = append(cmd.Env,
326-
os.Environ()...,
327-
)
328-
329-
cmd.Env = append(cmd.Env,
330-
// Add the APP_FOLDER environment variable which is required by the old code
331-
fmt.Sprintf("APP_FOLDER=%s", path.Clean(cmd.Dir)),
325+
// Add the HF_TOKEN environment variable which is required by the diffusers library
326+
fmt.Sprintf("HF_TOKEN=%s", os.Getenv("HF_TOKEN")),
332327
// Set python to be unbuffered so we get logs in real time
333328
"PYTHONUNBUFFERED=1",
334329
)

api/pkg/server/handlers.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ func (apiServer *HelixAPIServer) createSession(res http.ResponseWriter, req *htt
223223
modelName = model.Model_Axolotl_Mistral7b
224224
}
225225
case types.SessionTypeImage:
226-
modelName = model.Model_Diffusers_SD35
226+
modelName = model.Model_Diffusers_SDTurbo
227227
}
228228

229229
sessionID := system.GenerateUUID()
runner/helix-diffusers/.python-version

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
3.10
1+
3.11

runner/helix-diffusers/main.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,7 @@
99

1010
import PIL
1111
import torch
12-
from diffusers.pipelines.stable_diffusion_3 import (
13-
StableDiffusion3Pipeline,
14-
)
12+
from diffusers import AutoPipelineForText2Image
1513
from fastapi import FastAPI, HTTPException
1614
from fastapi.middleware.cors import CORSMiddleware
1715
from fastapi.staticfiles import StaticFiles
@@ -25,7 +23,7 @@
2523
server_host = os.getenv("SERVER_HOST", "0.0.0.0")
2624
server_port = int(os.getenv("SERVER_PORT", 8000))
2725
server_url = f"http://{server_host}:{server_port}"
28-
model_id = os.getenv("MODEL_ID", "stabilityai/stable-diffusion-3.5-medium")
26+
model_id = os.getenv("MODEL_ID", "stabilityai/sd-turbo")
2927

3028

3129
class TextToImageInput(BaseModel):
@@ -46,16 +44,18 @@ def start(self, model_id: str):
4644
if torch.cuda.is_available():
4745
logger.info("Loading CUDA")
4846
self.device = "cuda"
49-
self.pipeline = StableDiffusion3Pipeline.from_pretrained(
47+
self.pipeline = AutoPipelineForText2Image.from_pretrained(
5048
model_id,
5149
torch_dtype=torch.bfloat16,
50+
local_files_only=True,
5251
).to(device=self.device)
5352
elif torch.backends.mps.is_available():
5453
logger.info("Loading MPS for Mac M Series")
5554
self.device = "mps"
56-
self.pipeline = StableDiffusion3Pipeline.from_pretrained(
55+
self.pipeline = AutoPipelineForText2Image.from_pretrained(
5756
model_id,
5857
torch_dtype=torch.bfloat16,
58+
local_files_only=True,
5959
).to(device=self.device)
6060
else:
6161
raise Exception("No CUDA or MPS device available")

runner/helix-diffusers/pyproject.toml

+11-3
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@ name = "helix-diffusers"
33
version = "0.1.0"
44
description = "Add your description here"
55
readme = "README.md"
6-
requires-python = ">=3.10"
6+
requires-python = ">=3.11"
77
dependencies = [
88
"accelerate>=1.1.1",
99
"diffusers>=0.31.0",
1010
"fastapi>=0.115.5",
1111
"httpx>=0.27.2",
1212
"protobuf>=5.28.3",
1313
"sentencepiece>=0.2.0",
14-
"torch>=2.5.1",
14+
"torch==2.5.1+cu124",
1515
"transformers>=4.46.3",
1616
"uvicorn>=0.32.1",
1717
]
@@ -24,6 +24,14 @@ dev = [
2424
"ruff>=0.8.0",
2525
]
2626

27+
[tool.uv.sources]
28+
torch = { index = "pytorch-cu124" }
29+
30+
[[tool.uv.index]]
31+
name = "pytorch-cu124"
32+
url = "https://download.pytorch.org/whl/cu124"
33+
explicit = true
34+
2735
[tool.ruff]
2836
line-length = 100
2937
indent-width = 4
@@ -33,4 +41,4 @@ target-version = "py311"
3341
filterwarnings = [
3442
"ignore::UserWarning",
3543
"ignore::DeprecationWarning",
36-
]
44+
]

0 commit comments

Comments
 (0)