Skip to content

Commit

Permalink
Simplify further
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso committed Jun 21, 2024
1 parent d49cb23 commit d9e5ecc
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 36 deletions.
6 changes: 1 addition & 5 deletions dispatch.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,7 @@ func Client(client *dispatchclient.Client) Option {

// Register registers a function.
func (d *Dispatch) Register(fn AnyFunction) {
d.RegisterPrimitive(fn.Name(), fn.Primitive())

// Bind the function to this endpoint, so that the function's
// Dispatch method can be used to dispatch calls.
fn.register(d)
d.RegisterPrimitive(fn.Register(d))
}

// RegisterPrimitive registers a primitive function.
Expand Down
3 changes: 2 additions & 1 deletion dispatchlambda/lambda.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ func Start(functions ...dispatch.AnyFunction) {
func Handler(functions ...dispatch.AnyFunction) lambda.Handler {
handler := &handler{functions: dispatchproto.FunctionMap{}}
for _, fn := range functions {
handler.functions[fn.Name()] = fn.Primitive()
name, primitive := fn.Register(nil)
handler.functions[name] = primitive
}
return handler
}
Expand Down
37 changes: 30 additions & 7 deletions dispatchtest/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,8 @@ import (

// Run runs a function and returns its result.
func Run[O any](call dispatchproto.Call, functions ...dispatch.AnyFunction) (O, error) {
runner := Runner{Functions: dispatchproto.FunctionMap{}}
for _, fn := range functions {
runner.Functions[fn.Name()] = fn.Primitive()
}
runner := NewRunner(functions...)

res := runner.Run(call.Request())

var output O
Expand Down Expand Up @@ -42,20 +40,45 @@ func Run[O any](call dispatchproto.Call, functions ...dispatch.AnyFunction) (O,

// Runner runs functions.
type Runner struct {
Functions dispatchproto.FunctionMap
functions dispatchproto.FunctionMap
}

// NewRunner creates a Runner.
func NewRunner(functions ...dispatch.AnyFunction) *Runner {
runner := &Runner{functions: dispatchproto.FunctionMap{}}
for _, fn := range functions {
runner.Register(fn)
}
return runner
}

// Run runs a function and returns its response.
// Register registers a function.
func (r *Runner) Register(fn dispatch.AnyFunction) {
name, primitive := fn.Register(nil)
r.RegisterPrimitive(name, primitive)
}

// RegisterPrimitive registers a primitive function.
func (r *Runner) RegisterPrimitive(name string, fn dispatchproto.Function) {
r.functions[name] = fn
}

// Run runs a function to completion and returns its response.
func (r *Runner) Run(req dispatchproto.Request) dispatchproto.Response {
for {
res := r.Functions.Run(context.Background(), req)
res := r.RoundTrip(req)
if _, ok := res.Exit(); ok {
return res
}
req = r.poll(req, res)
}
}

// RoundTrip sends a request to a function and returns its response.
func (r Runner) RoundTrip(req dispatchproto.Request) dispatchproto.Response {
return r.functions.Run(context.Background(), req)
}

func (r *Runner) poll(req dispatchproto.Request, res dispatchproto.Response) dispatchproto.Request {
poll, ok := res.Poll()
if !ok {
Expand Down
22 changes: 6 additions & 16 deletions function.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,6 @@ func (f *Function[I, O]) Dispatch(ctx context.Context, input I, opts ...dispatch
return client.Dispatch(ctx, call)
}

// Primitive returns the associated primitive function.
func (f *Function[I, O]) Primitive() dispatchproto.Function {
return f.run
}

func (f *Function[I, O]) run(ctx context.Context, req dispatchproto.Request) dispatchproto.Response {
if name := req.Function(); name != f.name {
return dispatchproto.NewResponseErrorf("%w: function %q received call for function %q", ErrInvalidArgument, f.name, name)
Expand Down Expand Up @@ -183,8 +178,11 @@ func (f *Function[I, O]) deserialize(state dispatchproto.Any) (coroutineID, disp
return id, coro, err
}

func (f *Function[I, O]) register(endpoint *Dispatch) {
// Register is called when the function is registered
// on a Dispatch endpoint.
func (f *Function[I, O]) Register(endpoint *Dispatch) (string, dispatchproto.Function) {
f.endpoint = endpoint
return f.name, f.run
}

func (c *Function[I, O]) entrypoint(input I) func() dispatchproto.Response {
Expand Down Expand Up @@ -238,17 +236,9 @@ func (f *Function[I, O]) Gather(inputs []I, opts ...dispatchproto.CallOption) ([
return dispatchcoro.Gather[O](calls...)
}

// AnyFunction is the interface implemented by all Function[I, O] instances.
// AnyFunction is a Function[I, O] instance.
type AnyFunction interface {
// Name is the name of the function.
Name() string

// Primitive is the primitive dispatchproto.Function.
Primitive() dispatchproto.Function

// register is an internal hook which binds the function to
// a Dispatch endpoint, allowing its Dispatch method to be called.
register(*Dispatch)
Register(*Dispatch) (string, dispatchproto.Function)
}

// "Instances" are only applicable when coroutines are running
Expand Down
22 changes: 15 additions & 7 deletions function_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,13 @@ func TestCoroutinePoll(t *testing.T) {
return repeated, nil
})

runner := dispatchtest.NewRunner(repeat)

// Continously run the coroutine until it returns/exits.
var req dispatchproto.Request = dispatchproto.NewRequest("repeat", dispatchproto.Int(3))
var res dispatchproto.Response
for {
res = repeat.Primitive()(context.Background(), req)
res = runner.RoundTrip(req)
if res.Status() != dispatchproto.OKStatus {
t.Errorf("unexpected status: %s", res.Status())
}
Expand Down Expand Up @@ -226,12 +228,14 @@ func TestCoroutineAwait(t *testing.T) {

const repeatCount = 3

runner := dispatchtest.NewRunner(repeat)

req := dispatchproto.NewRequest("repeat", dispatchproto.Int(repeatCount))
var res dispatchproto.Response

requestCount := 0
for {
res = repeat.Primitive()(context.Background(), req)
res = runner.RoundTrip(req)
if res.Status() != dispatchproto.OKStatus {
t.Errorf("unexpected status: %s", res.Status())
}
Expand Down Expand Up @@ -309,8 +313,10 @@ func TestCoroutineGather(t *testing.T) {

const repeatCount = 3

runner := dispatchtest.NewRunner(repeat)

req := dispatchproto.NewRequest("repeat", dispatchproto.Int(repeatCount))
res := repeat.Primitive()(context.Background(), req)
res := runner.RoundTrip(req)
if res.Status() != dispatchproto.OKStatus {
t.Errorf("unexpected status: %s", res.Status())
}
Expand All @@ -337,7 +343,7 @@ func TestCoroutineGather(t *testing.T) {
dispatchproto.CallResults(callResults...))

req = dispatchproto.NewRequest("repeat", pollResult)
res = repeat.Primitive()(context.Background(), req)
res = runner.RoundTrip(req)
if res.Status() != dispatchproto.OKStatus {
t.Errorf("unexpected status: %s", res.Status())
}
Expand Down Expand Up @@ -387,8 +393,10 @@ func TestCoroutineGatherSlow(t *testing.T) {

const repeatCount = 3

runner := dispatchtest.NewRunner(repeat)

req := dispatchproto.NewRequest("repeat", dispatchproto.Int(repeatCount))
res := repeat.Primitive()(context.Background(), req)
res := runner.RoundTrip(req)
if res.Status() != dispatchproto.OKStatus {
t.Errorf("unexpected status: %s", res.Status())
}
Expand Down Expand Up @@ -416,7 +424,7 @@ func TestCoroutineGatherSlow(t *testing.T) {

// Deliver an empty poll result, to assert it's a noop.
req = dispatchproto.NewRequest("repeat", poll.Result())
res = repeat.Primitive()(context.Background(), req)
res = runner.RoundTrip(req)
if res.Status() != dispatchproto.OKStatus {
t.Errorf("unexpected status: %s", res.Status())
}
Expand All @@ -430,7 +438,7 @@ func TestCoroutineGatherSlow(t *testing.T) {
pollResult := poll.Result().With(dispatchproto.CallResults(callResults[i]))

req = dispatchproto.NewRequest("repeat", pollResult)
res = repeat.Primitive()(context.Background(), req)
res = runner.RoundTrip(req)
if res.Status() != dispatchproto.OKStatus {
t.Errorf("unexpected status: %s", res.Status())
}
Expand Down

0 comments on commit d9e5ecc

Please sign in to comment.