Skip to content

Commit 61e92db

Browse files
authored
feat(runner)!: NATS based runner communication architecture (#812)
This introduces a big re-architecture of the way we communicate with runners. Previously, runners were expected to reach out to the control plane. This meant the scheduler wasn't truly in control over what runners were doing and resulted in a variety of hard to manage race conditions. Background Now (with one exception noted below), the runner connects to the runner once to establish a web socket connection, which NATS then uses. All subsequent communication is RPC from the control plane via NATS. The runner is implemented as a simple web server with the usual REST semantics. So from the control plane, it's almost like doing a simple REST request to tell the runner what to do, which makes everything much simpler to manage and will ultimately be far more stable. Known Problems The various "upload files" paths to upload results to the control plane are the same. Image fine tuning wasn't implemented. (NTH) image generation request doesn't propagate session errors properly, probably need to wire that in like i did text (NTH) Slots aren't shown on the dash until models are running. But some of them take a while to load, like flux. Would be nice if they were shown immediately with a status. (NTH) Weird situation when hammering and the hammered nodes have a bit more free GPU left compared to a very stale node. It prefers to try and delete stales on this because there's more room. It's a scheduling strategy improvement. (NTH) More optimisation could be done on the cold start, since we're limiting to one new slot at a time for safety. Could lock over runner IDs so that it's one start per runner. BREAKING CHANGE: You must upgrade your runners and control plane to this version at the same time. Neither are backwards compatible.
1 parent d3dfdac commit 61e92db

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

81 files changed

+6374
-8542
lines changed

.drone.yml

-14
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,6 @@ steps:
159159
- name: dockersocket
160160
path: /var/run/docker.sock
161161
when:
162-
branch:
163-
- main
164162
event:
165163
- tag
166164
- push
@@ -243,8 +241,6 @@ steps:
243241
- name: dockersocket
244242
path: /var/run/docker.sock
245243
when:
246-
branch:
247-
- main
248244
event:
249245
- tag
250246

@@ -316,8 +312,6 @@ steps:
316312
- name: dockersocket
317313
path: /var/run/docker.sock
318314
when:
319-
branch:
320-
- main
321315
event:
322316
- tag
323317

@@ -380,8 +374,6 @@ steps:
380374
- name: dockersocket
381375
path: /var/run/docker.sock
382376
when:
383-
branch:
384-
- main
385377
event:
386378
- tag
387379
- push
@@ -416,8 +408,6 @@ steps:
416408
- name: dockersocket
417409
path: /var/run/docker.sock
418410
when:
419-
branch:
420-
- main
421411
event:
422412
- tag
423413
- push
@@ -452,8 +442,6 @@ steps:
452442
- name: dockersocket
453443
path: /var/run/docker.sock
454444
when:
455-
branch:
456-
- main
457445
event:
458446
- tag
459447
- push
@@ -488,8 +476,6 @@ steps:
488476
- name: dockersocket
489477
path: /var/run/docker.sock
490478
when:
491-
branch:
492-
- main
493479
event:
494480
- tag
495481
- push

api/cmd/helix/qapairs.go

+5-2
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,14 @@ func newQapairCommand() *cobra.Command {
2424
if err != nil {
2525
return fmt.Errorf("failed to load server config: %v", err)
2626
}
27-
ps, err := pubsub.New(serverConfig.PubSub.StoreDir)
27+
ps, err := pubsub.New(&serverConfig)
28+
if err != nil {
29+
return err
30+
}
31+
scheduler, err := scheduler.NewScheduler(cmd.Context(), &serverConfig, nil)
2832
if err != nil {
2933
return err
3034
}
31-
scheduler := scheduler.NewScheduler(cmd.Context(), &serverConfig, nil)
3235
helixInference := openai.NewInternalHelixServer(&serverConfig, ps, scheduler)
3336
client, err := createDataPrepOpenAIClient(&serverConfig, helixInference)
3437
if err != nil {

api/cmd/helix/root.go

+1-6
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package helix
33
import (
44
"context"
55
"os"
6-
"runtime"
76

87
"github.com/spf13/cobra"
98

@@ -46,11 +45,7 @@ func NewRootCmd() *cobra.Command {
4645
RootCmd.AddCommand(newQapairCommand())
4746
RootCmd.AddCommand(newEvalsCommand())
4847
RootCmd.AddCommand(NewTestCmd()) // Use the NewTestCmd function from the current package
49-
50-
// Runner only works on Linux
51-
if runtime.GOOS == "linux" {
52-
RootCmd.AddCommand(newRunnerCmd())
53-
}
48+
RootCmd.AddCommand(newRunnerCmd())
5449

5550
return RootCmd
5651
}

api/cmd/helix/runner.go

+15-6
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ func NewRunnerOptions() *RunnerOptions {
4343
AllowMultipleCopies: getDefaultServeOptionBool("ALLOW_MULTIPLE_COPIES", false),
4444
MaxModelInstances: getDefaultServeOptionInt("MAX_MODEL_INSTANCES", 0),
4545
CacheDir: getDefaultServeOptionString("CACHE_DIR", "/root/.cache/huggingface"), // TODO: change to maybe just /data
46+
WebServer: runner.WebServer{
47+
Host: getDefaultServeOptionString("SERVER_HOST", "127.0.0.1"),
48+
Port: getDefaultServeOptionInt("SERVER_PORT", 80),
49+
},
4650
},
4751
Janitor: config.Janitor{
4852
SentryDsnAPI: getDefaultServeOptionString("SENTRY_DSN_API", ""),
@@ -86,6 +90,16 @@ func newRunnerCmd() *cobra.Command {
8690
`The auth token for this runner`,
8791
)
8892

93+
runnerCmd.PersistentFlags().StringVar(
94+
&allOptions.Runner.WebServer.Host, "server-host", allOptions.Runner.WebServer.Host,
95+
`The host to bind the api server to.`,
96+
)
97+
98+
runnerCmd.PersistentFlags().IntVar(
99+
&allOptions.Runner.WebServer.Port, "server-port", allOptions.Runner.WebServer.Port,
100+
`The port to bind the api server to.`,
101+
)
102+
89103
runnerCmd.PersistentFlags().Uint64Var(
90104
&allOptions.Runner.MemoryBytes, "memory-bytes", allOptions.Runner.MemoryBytes,
91105
`The number of bytes of GPU memory available - e.g. 1073741824`,
@@ -288,12 +302,7 @@ func runnerCLI(cmd *cobra.Command, options *RunnerOptions) error {
288302
return err
289303
}
290304

291-
err = runnerController.Initialize(ctx)
292-
if err != nil {
293-
return err
294-
}
295-
296-
go runnerController.Run()
305+
go runnerController.Run(ctx)
297306

298307
<-ctx.Done()
299308
return nil

api/cmd/helix/serve.go

+44-46
Original file line numberDiff line numberDiff line change
@@ -219,9 +219,9 @@ func serve(cmd *cobra.Command, cfg *config.ServerConfig) error {
219219
return err
220220
}
221221

222-
ps, err := pubsub.New(cfg.PubSub.StoreDir)
222+
ps, err := pubsub.New(cfg)
223223
if err != nil {
224-
return err
224+
return fmt.Errorf("failed to create pubsub provider: %w", err)
225225
}
226226

227227
if cfg.WebServer.RunnerToken == "" {
@@ -265,49 +265,50 @@ func serve(cmd *cobra.Command, cfg *config.ServerConfig) error {
265265
return fmt.Errorf("unknown extractor: %s", cfg.TextExtractor.Provider)
266266
}
267267

268-
// Must use the same allocator for both new LLM requests and old sessions
269-
scheduler := scheduler.NewScheduler(ctx, cfg, func(work *scheduler.Workload, err error) {
270-
// This function describes what happens when errors occur in jobs.
271-
// Each request type (session vs. LLM requests) has a differeht code path handling results,
272-
// hence for now we need to separate cases to handle errors.
273-
switch work.WorkloadType {
274-
case scheduler.WorkloadTypeLLMInferenceRequest:
275-
log.Warn().Err(err).Str("id", work.ID()).Msg("error scheduling work, removing from queue")
276-
req := work.LLMInferenceRequest()
277-
resp := &types.RunnerLLMInferenceResponse{
278-
RequestID: req.RequestID,
279-
OwnerID: req.OwnerID,
280-
SessionID: req.SessionID,
281-
InteractionID: req.InteractionID,
282-
Error: err.Error(),
283-
Done: true,
284-
}
285-
bts, err := json.Marshal(resp)
286-
if err != nil {
287-
log.Error().Err(err).Str("id", work.ID()).Msg("error marshalling runner response")
288-
}
268+
runnerController, err := scheduler.NewRunnerController(ctx, &scheduler.RunnerControllerConfig{
269+
PubSub: ps,
270+
FS: fs,
271+
})
272+
if err != nil {
273+
return err
274+
}
289275

290-
err = ps.Publish(context.Background(), pubsub.GetRunnerResponsesQueue(req.OwnerID, req.RequestID), bts)
291-
if err != nil {
292-
log.Error().Err(err).Str("id", work.ID()).Msg("error publishing runner response")
293-
}
294-
case scheduler.WorkloadTypeSession:
295-
// If we can't retry, write an error to the request and continue so it takes it off
296-
// the queue
297-
errSession := work.Session()
298-
errSession.Interactions = append(errSession.Interactions, &types.Interaction{
299-
Creator: types.CreatorTypeSystem,
300-
Error: err.Error(),
301-
Message: "Error scheduling session",
302-
})
303-
_, err = store.UpdateSession(ctx, *errSession)
304-
if err != nil {
305-
log.Error().Err(err).Msg("error updating session")
276+
var appController *controller.Controller
277+
278+
scheduler, err := scheduler.NewScheduler(ctx, cfg, &scheduler.Params{
279+
RunnerController: runnerController,
280+
QueueSize: 100,
281+
OnSchedulingErr: func(work *scheduler.Workload, err error) {
282+
if appController != nil {
283+
switch work.WorkloadType {
284+
case scheduler.WorkloadTypeLLMInferenceRequest:
285+
request := work.LLMInferenceRequest()
286+
response := types.RunnerNatsReplyResponse{
287+
OwnerID: request.OwnerID,
288+
RequestID: request.RequestID,
289+
Error: err.Error(),
290+
Response: []byte{},
291+
}
292+
bts, err := json.Marshal(response)
293+
if err != nil {
294+
log.Error().Err(err).Msg("error marshalling runner response")
295+
}
296+
err = ps.Publish(ctx, pubsub.GetRunnerResponsesQueue(request.OwnerID, request.RequestID), bts)
297+
if err != nil {
298+
log.Error().Err(err).Msg("error publishing runner response")
299+
}
300+
case scheduler.WorkloadTypeSession:
301+
appController.ErrorSession(ctx, work.Session(), err)
302+
}
306303
}
307-
default:
308-
log.Error().Str("workload_type", string(work.WorkloadType)).Msg("unknown workload type")
309-
}
304+
},
305+
OnResponseHandler: func(_ context.Context, _ *types.RunnerLLMInferenceResponse) error {
306+
return nil
307+
},
310308
})
309+
if err != nil {
310+
return err
311+
}
311312

312313
helixInference := openai.NewInternalHelixServer(cfg, ps, scheduler)
313314

@@ -354,8 +355,6 @@ func serve(cmd *cobra.Command, cfg *config.ServerConfig) error {
354355
return fmt.Errorf("unknown RAG provider: %s", cfg.RAG.DefaultRagProvider)
355356
}
356357

357-
var appController *controller.Controller
358-
359358
controllerOptions := controller.Options{
360359
Config: cfg,
361360
Store: store,
@@ -369,6 +368,7 @@ func serve(cmd *cobra.Command, cfg *config.ServerConfig) error {
369368
ProviderManager: providerManager,
370369
DataprepOpenAIClient: dataprepOpenAIClient,
371370
Scheduler: scheduler,
371+
RunnerController: runnerController,
372372
}
373373

374374
appController, err = controller.NewController(ctx, controllerOptions)
@@ -381,8 +381,6 @@ func serve(cmd *cobra.Command, cfg *config.ServerConfig) error {
381381
return err
382382
}
383383

384-
go appController.Start(ctx)
385-
386384
// Initialize browser pool
387385
browserPool, err := browser.New(cfg)
388386
if err != nil {

api/pkg/config/config.go

+9
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,15 @@ type FileStore struct {
248248

249249
type PubSub struct {
250250
StoreDir string `envconfig:"NATS_STORE_DIR" default:"/filestore/nats" description:"The directory to store nats data."`
251+
Provider string `envconfig:"PUBSUB_PROVIDER" default:"nats" description:"The pubsub provider to use (nats or inmemory)."`
252+
Server struct {
253+
EmbeddedNatsServerEnabled bool `envconfig:"NATS_SERVER_EMBEDDED_ENABLED" default:"true" description:"Whether to enable the embedded NATS server."`
254+
Host string `envconfig:"NATS_SERVER_HOST" default:"127.0.0.1" description:"The host to bind the NATS server to."`
255+
Port int `envconfig:"NATS_SERVER_PORT" default:"8433" description:"The port to bind the NATS server to."`
256+
Token string `envconfig:"NATS_SERVER_TOKEN" description:"The authentication token for the NATS server."`
257+
MaxPayload int `envconfig:"NATS_SERVER_MAX_PAYLOAD" default:"33554432" description:"The maximum payload size in bytes (default 32MB)."`
258+
JetStream bool `envconfig:"NATS_SERVER_JETSTREAM" default:"true" description:"Whether to enable JetStream."`
259+
}
251260
}
252261

253262
type Store struct {

api/pkg/controller/controller.go

+3-36
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ package controller
33
import (
44
"context"
55
"fmt"
6-
"runtime/debug"
7-
"time"
86

97
"github.com/helixml/helix/api/pkg/config"
108
"github.com/helixml/helix/api/pkg/extract"
@@ -21,8 +19,6 @@ import (
2119
"github.com/helixml/helix/api/pkg/store"
2220
"github.com/helixml/helix/api/pkg/tools"
2321
"github.com/helixml/helix/api/pkg/types"
24-
"github.com/puzpuzpuz/xsync/v3"
25-
"github.com/rs/zerolog/log"
2622
)
2723

2824
type Options struct {
@@ -38,7 +34,8 @@ type Options struct {
3834
// OpenAIClient openai.Client
3935
ProviderManager manager.ProviderManager
4036
DataprepOpenAIClient openai.Client
41-
Scheduler scheduler.Scheduler
37+
Scheduler *scheduler.Scheduler
38+
RunnerController *scheduler.RunnerController
4239
}
4340

4441
type Controller struct {
@@ -56,14 +53,10 @@ type Controller struct {
5653
// the models package looks after instantiating this for us
5754
models map[string]model.Model
5855

59-
// the map of model instances that we have loaded
60-
// and are currently running
61-
activeRunners *xsync.MapOf[string, *types.RunnerState]
62-
6356
// the current buffer of scheduling decisions
6457
schedulingDecisions []*types.GlobalSchedulingDecision
6558

66-
scheduler scheduler.Scheduler
59+
scheduler *scheduler.Scheduler
6760
}
6861

6962
func NewController(
@@ -100,7 +93,6 @@ func NewController(
10093
newRagClient: func(settings *types.RAGSettings) rag.RAG {
10194
return rag.NewLlamaindex(settings)
10295
},
103-
activeRunners: xsync.NewMapOf[string, *types.RunnerState](),
10496
schedulingDecisions: []*types.GlobalSchedulingDecision{},
10597
scheduler: options.Scheduler,
10698
}
@@ -123,28 +115,3 @@ func NewController(
123115
func (c *Controller) Initialize() error {
124116
return nil
125117
}
126-
127-
// this should be run in a go-routine
128-
func (c *Controller) Start(ctx context.Context) {
129-
for {
130-
select {
131-
case <-ctx.Done():
132-
return
133-
case <-time.After(10 * time.Second):
134-
err := c.run(c.Ctx)
135-
if err != nil {
136-
log.Error().Msgf("error in controller loop: %s", err.Error())
137-
debug.PrintStack()
138-
}
139-
}
140-
}
141-
}
142-
143-
func (c *Controller) run(ctx context.Context) error {
144-
err := c.cleanOldRunnerMetrics(ctx)
145-
if err != nil {
146-
log.Error().Msgf("error in controller loop: %s", err.Error())
147-
debug.PrintStack()
148-
}
149-
return nil
150-
}

0 commit comments

Comments
 (0)