Skip to content

Commit 08830e1

Browse files
committed
feat(runner): add basic diffusers server
1 parent 6bc8777 commit 08830e1

14 files changed

+2483
-8
lines changed

Dockerfile.runner

+11
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
ARG TAG=2024-11-21a-empty
44

5+
FROM ghcr.io/astral-sh/uv:0.5.4 as uv
6+
57
### BUILD
68

79
FROM golang:1.22 AS go-build-env
@@ -45,6 +47,15 @@ WORKDIR /workspace/helix
4547
# Copy runner directory from the repo
4648
COPY runner ./runner
4749

50+
# We need to set this environment variable so that uv knows where
51+
# the virtual environment is to install packages
52+
ENV UV_PROJECT_ENVIRONMENT=/workspace/helix/runner/helix-diffusers/venv
53+
54+
# Install the packages with uv using --mount=type=cache to cache the downloaded packages
55+
RUN --mount=type=cache,target=/root/.cache/uv \
56+
--mount=from=uv,source=/uv,target=/usr/bin/uv \
57+
cd /workspace/helix/runner/helix-diffusers && uv sync --no-dev
58+
4859
# Copy the cog wrapper, cog and cog-sdxl is installed in the base image, this is just the cog server
4960
COPY cog/helix_cog_wrapper.py /workspace/cog-sdxl/helix_cog_wrapper.py
5061

api/pkg/model/diffusers_generic.go

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
package model
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"os/exec"
7+
8+
"github.com/helixml/helix/api/pkg/types"
9+
)
10+
11+
var _ Model = &DiffusersGenericImage{}
12+
13+
type DiffusersGenericImage struct {
14+
Id string // e.g. "stabilityai/stable-diffusion-3.5-medium"
15+
Name string // e.g. "Stable Diffusion 3.5 Medium"
16+
Memory uint64
17+
Description string
18+
Hide bool
19+
}
20+
21+
func (i *DiffusersGenericImage) GetMemoryRequirements(mode types.SessionMode) uint64 {
22+
return i.Memory
23+
}
24+
25+
func (i *DiffusersGenericImage) GetType() types.SessionType {
26+
return types.SessionTypeImage
27+
}
28+
29+
func (i *DiffusersGenericImage) GetID() string {
30+
return i.Id
31+
}
32+
33+
func (i *DiffusersGenericImage) ModelName() ModelName {
34+
return NewModel(i.Id)
35+
}
36+
37+
func (i *DiffusersGenericImage) GetTask(session *types.Session, fileManager ModelSessionFileManager) (*types.RunnerTask, error) {
38+
task, err := getGenericTask(session)
39+
if err != nil {
40+
return nil, err
41+
}
42+
43+
return task, nil
44+
}
45+
46+
func (i *DiffusersGenericImage) GetCommand(ctx context.Context, sessionFilter types.SessionFilter, config types.RunnerProcessConfig) (*exec.Cmd, error) {
47+
return nil, fmt.Errorf("not implemented 1")
48+
}
49+
50+
func (i *DiffusersGenericImage) GetTextStreams(mode types.SessionMode, eventHandler WorkerEventHandler) (*TextStream, *TextStream, error) {
51+
return nil, nil, fmt.Errorf("not implemented 2")
52+
}
53+
54+
func (i *DiffusersGenericImage) PrepareFiles(session *types.Session, isInitialSession bool, fileManager ModelSessionFileManager) (*types.Session, error) {
55+
return nil, fmt.Errorf("not implemented 3")
56+
}
57+
58+
func (i *DiffusersGenericImage) GetDescription() string {
59+
return i.Description
60+
}
61+
62+
func (i *DiffusersGenericImage) GetHumanReadableName() string {
63+
return i.Name
64+
}
65+
66+
func (i *DiffusersGenericImage) GetHidden() bool {
67+
return i.Hide
68+
}

api/pkg/model/models.go

+31-1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,16 @@ func (m ModelName) InferenceRuntime() types.InferenceRuntime {
5353
if m.String() == Model_Cog_SDXL {
5454
return types.InferenceRuntimeCog
5555
}
56+
diffusersModels, err := GetDefaultDiffusersModels()
57+
if err != nil {
58+
return types.InferenceRuntimeAxolotl
59+
}
60+
for _, model := range diffusersModels {
61+
if m.String() == model.Id {
62+
return types.InferenceRuntimeDiffusers
63+
}
64+
}
65+
5666
// misnamed: axolotl runtime handles axolotl and cog/sd-scripts
5767
return types.InferenceRuntimeAxolotl
5868
}
@@ -112,7 +122,7 @@ func ProcessModelName(
112122
}
113123
}
114124
case types.SessionTypeImage:
115-
return Model_Cog_SDXL, nil
125+
return Model_Diffusers_SD35, nil
116126
}
117127

118128
// shouldn't get here
@@ -133,12 +143,20 @@ func GetModels() (map[string]Model, error) {
133143
for _, model := range ollamaModels {
134144
models[model.Id] = model
135145
}
146+
diffusersModels, err := GetDefaultDiffusersModels()
147+
if err != nil {
148+
return nil, err
149+
}
150+
for _, model := range diffusersModels {
151+
models[model.Id] = model
152+
}
136153
return models, nil
137154
}
138155

139156
const (
140157
Model_Axolotl_Mistral7b string = "mistralai/Mistral-7B-Instruct-v0.1"
141158
Model_Cog_SDXL string = "stabilityai/stable-diffusion-xl-base-1.0"
159+
Model_Diffusers_SD35 string = "stabilityai/stable-diffusion-3.5-medium"
142160

143161
// We only need constants for _some_ ollama models that are hardcoded in
144162
// various places (backward compat). Other ones can be added dynamically now.
@@ -149,6 +167,18 @@ const (
149167
Model_Ollama_Phi3 string = "phi3:instruct"
150168
)
151169

170+
func GetDefaultDiffusersModels() ([]*DiffusersGenericImage, error) {
171+
return []*DiffusersGenericImage{
172+
{
173+
Id: Model_Diffusers_SD35,
174+
Name: "Stable Diffusion 3.5 Medium",
175+
Memory: GB * 21,
176+
Description: "Medium model, from Stability AI",
177+
Hide: false,
178+
},
179+
}, nil
180+
}
181+
152182
// See also types/models.go for model name constants
153183
func GetDefaultOllamaModels() ([]*OllamaGenericText, error) {
154184
models := []*OllamaGenericText{

0 commit comments

Comments
 (0)