Skip to content

Commit

Permalink
Move more internal logic to dispatchcoro
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso committed Jun 21, 2024
1 parent d9e5ecc commit 7750571
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 76 deletions.
5 changes: 5 additions & 0 deletions dispatchcoro/coroutine.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ import (
// Coroutine is the flavour of coroutine supported by Dispatch and the SDK.
type Coroutine = coroutine.Coroutine[dispatchproto.Response, dispatchproto.Request]

// New creates a Coroutine.
func New(fn func() dispatchproto.Response) Coroutine {
return coroutine.NewWithReturn[dispatchproto.Response, dispatchproto.Request](fn)
}

// Yield yields control to Dispatch.
//
// The coroutine is suspended while the Response is sent to Dispatch.
Expand Down
75 changes: 75 additions & 0 deletions dispatchcoro/volatile.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package dispatchcoro

import (
"fmt"
"math/rand/v2"
"sync"
)

// VolatileCoroutines is a set of volatile coroutine instances.
//
// "Instances" are only applicable when coroutines are running
// in volatile mode, since suspended coroutines must be kept in
// memory. In durable mode, there's no need to keep instances
// around, since they can be serialized and later recreated.
type VolatileCoroutines struct {
instances map[InstanceID]Coroutine
nextID InstanceID
mu sync.Mutex
}

// InstanceID is a unique identifier for a coroutine instance.
type InstanceID = uint64

// Register registers a coroutine instance and returns a unique
// identifier.
func (f *VolatileCoroutines) Register(coro Coroutine) InstanceID {
f.mu.Lock()
defer f.mu.Unlock()

if f.nextID == 0 {
f.nextID = rand.Uint64()
}
f.nextID++

id := f.nextID
if f.instances == nil {
f.instances = map[InstanceID]Coroutine{}
}
f.instances[id] = coro

return id
}

// Find finds the coroutine instance with the specified ID.
func (f *VolatileCoroutines) Find(id InstanceID) (Coroutine, error) {
f.mu.Lock()
defer f.mu.Unlock()

coro, ok := f.instances[id]
if !ok {
return coro, fmt.Errorf("volatile coroutine %d not found", id)
}
return coro, nil
}

// Delete deletes a coroutine instance.
func (f *VolatileCoroutines) Delete(id InstanceID) {
f.mu.Lock()
defer f.mu.Unlock()

delete(f.instances, id)
}

// Close closes the set of coroutine instances.
func (f *VolatileCoroutines) Close() error {
f.mu.Lock()
defer f.mu.Unlock()

for _, fn := range f.instances {
fn.Stop()
fn.Next()
}
clear(f.instances)
return nil
}
88 changes: 12 additions & 76 deletions function.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ package dispatch
import (
"context"
"fmt"
"math/rand/v2"
"slices"
"sync"

"github.com/dispatchrun/coroutine"
"github.com/dispatchrun/dispatch-go/dispatchcoro"
Expand All @@ -27,7 +25,7 @@ type Function[I, O any] struct {

endpoint *Dispatch

volatileCoroutines
instances dispatchcoro.VolatileCoroutines
}

// Name is the name of the function.
Expand Down Expand Up @@ -103,7 +101,7 @@ func (f *Function[I, O]) run(ctx context.Context, req dispatchproto.Request) dis
return yield.With(dispatchproto.CoroutineState(state))
}

func (f *Function[I, O]) setUp(req dispatchproto.Request) (coroutineID, dispatchcoro.Coroutine, error) {
func (f *Function[I, O]) setUp(req dispatchproto.Request) (dispatchcoro.InstanceID, dispatchcoro.Coroutine, error) {
// If the request carries a poll result, find/deserialize the
// suspended coroutine.
if pollResult, ok := req.PollResult(); ok {
Expand All @@ -119,20 +117,20 @@ func (f *Function[I, O]) setUp(req dispatchproto.Request) (coroutineID, dispatch
if err := boxedInput.Unmarshal(&input); err != nil {
return 0, dispatchcoro.Coroutine{}, fmt.Errorf("%w: invalid input %v: %v", ErrInvalidArgument, boxedInput, err)
}
coro := coroutine.NewWithReturn[dispatchproto.Response, dispatchproto.Request](f.entrypoint(input))
coro := dispatchcoro.New(f.entrypoint(input))

// In volatile mode, register the coroutine instance and assign a unique ID.
var id coroutineID
var id dispatchcoro.InstanceID
if !coroutine.Durable {
id = f.registerCoroutineInstance(coro)
id = f.instances.Register(coro)
}
return id, coro, nil
}

func (f *Function[I, O]) tearDown(id coroutineID, coro dispatchcoro.Coroutine) {
func (f *Function[I, O]) tearDown(id dispatchcoro.InstanceID, coro dispatchcoro.Coroutine) {
// Remove volatile coroutine instances only once they're done.
if !coroutine.Durable && coro.Done() {
f.volatileCoroutines.deleteCoroutineInstance(id)
f.instances.Delete(id)
}

// Always tear down durable coroutines. They'll be rebuilt
Expand All @@ -144,7 +142,7 @@ func (f *Function[I, O]) tearDown(id coroutineID, coro dispatchcoro.Coroutine) {
}
}

func (f *Function[I, O]) serialize(id coroutineID, coro dispatchcoro.Coroutine) (dispatchproto.Any, error) {
func (f *Function[I, O]) serialize(id dispatchcoro.InstanceID, coro dispatchcoro.Coroutine) (dispatchproto.Any, error) {
// In volatile mode, serialize a reference to the coroutine instance.
if !coroutine.Durable {
return dispatchproto.NewAny(id)
Expand All @@ -158,23 +156,23 @@ func (f *Function[I, O]) serialize(id coroutineID, coro dispatchcoro.Coroutine)
return state, nil
}

func (f *Function[I, O]) deserialize(state dispatchproto.Any) (coroutineID, dispatchcoro.Coroutine, error) {
func (f *Function[I, O]) deserialize(state dispatchproto.Any) (dispatchcoro.InstanceID, dispatchcoro.Coroutine, error) {
// In durable mode, create the coroutine and then deserialize its prior state.
if coroutine.Durable {
var zero I
coro := coroutine.NewWithReturn[dispatchproto.Response, dispatchproto.Request](f.entrypoint(zero))
coro := dispatchcoro.New(f.entrypoint(zero))
if err := dispatchcoro.Deserialize(coro, state); err != nil {
return 0, dispatchcoro.Coroutine{}, fmt.Errorf("%w: %v", ErrIncompatibleState, err)
}
return 0, coro, nil
}

// In volatile mode, find the suspended coroutine instance.
var id coroutineID
var id dispatchcoro.InstanceID
if err := state.Unmarshal(&id); err != nil {
return 0, dispatchcoro.Coroutine{}, fmt.Errorf("%w: invalid volatile coroutine reference: %s", ErrIncompatibleState, state)
}
coro, err := f.findCoroutineInstance(id)
coro, err := f.instances.Find(id)
return id, coro, err
}

Expand Down Expand Up @@ -240,65 +238,3 @@ func (f *Function[I, O]) Gather(inputs []I, opts ...dispatchproto.CallOption) ([
type AnyFunction interface {
Register(*Dispatch) (string, dispatchproto.Function)
}

// "Instances" are only applicable when coroutines are running
// in volatile mode, since we must be keep suspended coroutines in
// memory while they're polling. In durable mode, there's no need
// to keep "instances" around, since we can serialize the state of
// each coroutine and send it back and forth to Dispatch. In durable
// mode Function[I,O] is stateless.
type volatileCoroutines struct {
instances map[coroutineID]dispatchcoro.Coroutine
nextID coroutineID
mu sync.Mutex
}

type coroutineID = uint64

func (f *volatileCoroutines) registerCoroutineInstance(coro dispatchcoro.Coroutine) coroutineID {
f.mu.Lock()
defer f.mu.Unlock()

if f.nextID == 0 {
f.nextID = rand.Uint64()
}
f.nextID++

id := f.nextID
if f.instances == nil {
f.instances = map[coroutineID]dispatchcoro.Coroutine{}
}
f.instances[id] = coro

return id
}

func (f *volatileCoroutines) findCoroutineInstance(id coroutineID) (dispatchcoro.Coroutine, error) {
f.mu.Lock()
defer f.mu.Unlock()

coro, ok := f.instances[id]
if !ok {
return coro, fmt.Errorf("%w: volatile coroutine %d not found", ErrIncompatibleState, id)
}
return coro, nil
}

func (f *volatileCoroutines) deleteCoroutineInstance(id coroutineID) {
f.mu.Lock()
defer f.mu.Unlock()

delete(f.instances, id)
}

func (f *volatileCoroutines) Close() error {
f.mu.Lock()
defer f.mu.Unlock()

for _, fn := range f.instances {
fn.Stop()
fn.Next()
}
clear(f.instances)
return nil
}

0 comments on commit 7750571

Please sign in to comment.