diff --git a/README.md b/README.md index d301079..788d863 100644 --- a/README.md +++ b/README.md @@ -26,17 +26,17 @@ import ( func main() { ctx := context.Background() - p1 := promise4g.New(ctx, func(resolve func(string), reject func(error)) { + p1 := promise4g.New(func(resolve func(string), reject func(error)) { time.Sleep(100 * time.Millisecond) resolve("one") }) - p2 := promise4g.New(ctx, func(resolve func(string), reject func(error)) { + p2 := promise4g.New(func(resolve func(string), reject func(error)) { time.Sleep(200 * time.Millisecond) resolve("two") }) - p3 := promise4g.New(ctx, func(resolve func(string), reject func(error)) { + p3 := promise4g.New(func(resolve func(string), reject func(error)) { time.Sleep(300 * time.Millisecond) resolve("three") }) diff --git a/promise.go b/promise.go index de29e91..dbadf37 100644 --- a/promise.go +++ b/promise.go @@ -4,36 +4,33 @@ import ( "context" "fmt" "sync" + "sync/atomic" + "time" ) // Promise represents a computation that will eventually be completed with a value of type T or an error. type Promise[T any] struct { - value T - err error - ch chan struct{} + value atomic.Value + err atomic.Value + done chan struct{} once sync.Once } -func New[T any]( - task func(resolve func(T), reject func(error))) *Promise[T] { +// New creates a new Promise with the given task +func New[T any](task func(resolve func(T), reject func(error))) *Promise[T] { return NewWithPool(task, defaultPool) } -func NewWithPool[T any]( - task func(resolve func(T), reject func(error)), - pool Pool) *Promise[T] { +// NewWithPool creates a new Promise with the given task and pool +func NewWithPool[T any](task func(resolve func(T), reject func(error)), pool Pool) *Promise[T] { if task == nil { panic("task must not be nil") } if pool == nil { panic("pool must not be nil") } - var t T p := &Promise[T]{ - value: t, - err: nil, - ch: make(chan struct{}), - once: sync.Once{}, + done: make(chan struct{}), } pool.Go(func() { defer p.handlePanic() @@ -42,103 +39,105 @@ func NewWithPool[T any]( return p } +// Await waits for the Promise to be resolved or rejected func (p *Promise[T]) Await(ctx context.Context) (T, error) { select { case <-ctx.Done(): var t T return t, ctx.Err() - case <-p.ch: - return p.value, p.err + case <-p.done: + if err := p.err.Load(); err != nil { + var t T + return t, err.(error) + } + return p.value.Load().(T), nil } } func (p *Promise[T]) resolve(value T) { p.once.Do(func() { - p.value = value - close(p.ch) + p.value.Store(value) + close(p.done) }) } func (p *Promise[T]) reject(err error) { p.once.Do(func() { - p.err = err - close(p.ch) + p.err.Store(err) + close(p.done) }) } func (p *Promise[T]) handlePanic() { - err := recover() - if err == nil { - return - } - switch v := err.(type) { - case error: - p.reject(v) - default: - p.reject(fmt.Errorf("%+v", v)) + if r := recover(); r != nil { + var err error + switch v := r.(type) { + case error: + err = v + default: + err = fmt.Errorf("%v", v) + } + p.reject(err) } } -func All[T any]( - ctx context.Context, - promises ...*Promise[T], -) *Promise[[]T] { +// All waits for all promises to be resolved +func All[T any](ctx context.Context, promises ...*Promise[T]) *Promise[[]T] { return AllWithPool(ctx, defaultPool, promises...) } -func AllWithPool[T any]( - ctx context.Context, - pool Pool, - promises ...*Promise[T], -) *Promise[[]T] { +// AllWithPool waits for all promises to be resolved using the given pool +func AllWithPool[T any](ctx context.Context, pool Pool, promises ...*Promise[T]) *Promise[[]T] { if len(promises) == 0 { panic("missing promises") } return NewWithPool(func(resolve func([]T), reject func(error)) { - resultsChan := make(chan tuple[T, int], len(promises)) - errsChan := make(chan error, len(promises)) - - for idx, p := range promises { - idx := idx - _ = ThenWithPool(p, ctx, func(data T) (T, error) { - resultsChan <- tuple[T, int]{_1: data, _2: idx} - return data, nil - }, pool) - _ = CatchWithPool(p, ctx, func(err error) error { - errsChan <- err - return err - }, pool) - } - results := make([]T, len(promises)) - for idx := 0; idx < len(promises); idx++ { - select { - case result := <-resultsChan: - results[result._2] = result._1 - case err := <-errsChan: - reject(err) - return - } + var wg sync.WaitGroup + wg.Add(len(promises)) + + for i, p := range promises { + i, p := i, p + pool.Go(func() { + defer wg.Done() + result, err := p.Await(ctx) + if err != nil { + reject(err) + return + } + results[i] = result + }) } + + wg.Wait() resolve(results) }, pool) } -func Then[A, B any]( - p *Promise[A], - ctx context.Context, - resolve func(A) (B, error), -) *Promise[B] { +// Race returns a promise that resolves or rejects as soon as one of the promises resolves or rejects +func Race[T any](ctx context.Context, promises ...*Promise[T]) *Promise[T] { + return NewWithPool(func(resolve func(T), reject func(error)) { + for _, p := range promises { + go func(p *Promise[T]) { + result, err := p.Await(ctx) + if err != nil { + reject(err) + } else { + resolve(result) + } + }(p) + } + }, defaultPool) +} + +// Then chains a new Promise to the current one +func Then[A, B any](p *Promise[A], ctx context.Context, resolve func(A) (B, error)) *Promise[B] { return ThenWithPool(p, ctx, resolve, defaultPool) } -func ThenWithPool[A, B any]( - p *Promise[A], - ctx context.Context, - resolve func(A) (B, error), - pool Pool, -) *Promise[B] { +// ThenWithPool chains a new Promise to the current one using the given pool +func ThenWithPool[A, B any](p *Promise[A], ctx context.Context, resolve func(A) (B, error), pool Pool) *Promise[B] { return NewWithPool(func(resolveB func(B), reject func(error)) { result, err := p.Await(ctx) if err != nil { @@ -156,20 +155,13 @@ func ThenWithPool[A, B any]( }, pool) } -func Catch[T any]( - p *Promise[T], - ctx context.Context, - reject func(err error) error, -) *Promise[T] { +// Catch handles errors in the Promise chain +func Catch[T any](p *Promise[T], ctx context.Context, reject func(error) error) *Promise[T] { return CatchWithPool(p, ctx, reject, defaultPool) } -func CatchWithPool[T any]( - p *Promise[T], - ctx context.Context, - reject func(err error) error, - pool Pool, -) *Promise[T] { +// CatchWithPool handles errors in the Promise chain using the given pool +func CatchWithPool[T any](p *Promise[T], ctx context.Context, reject func(error) error, pool Pool) *Promise[T] { return NewWithPool(func(resolve func(T), internalReject func(error)) { result, err := p.Await(ctx) if err != nil { @@ -180,7 +172,29 @@ func CatchWithPool[T any]( }, pool) } -type tuple[T1, T2 any] struct { - _1 T1 - _2 T2 +// Finally executes a function regardless of whether the promise is fulfilled or rejected +func Finally[T any](p *Promise[T], ctx context.Context, fn func()) *Promise[T] { + return NewWithPool(func(resolve func(T), reject func(error)) { + result, err := p.Await(ctx) + fn() + if err != nil { + reject(err) + } else { + resolve(result) + } + }, defaultPool) +} + +// Timeout returns a new Promise that rejects if the original Promise doesn't resolve within the specified duration +func Timeout[T any](p *Promise[T], d time.Duration) *Promise[T] { + ctx, cancel := context.WithTimeout(context.Background(), d) + return NewWithPool(func(resolve func(T), reject func(error)) { + defer cancel() + result, err := p.Await(ctx) + if err != nil { + reject(err) + } else { + resolve(result) + } + }, defaultPool) } diff --git a/promise_test.go b/promise_test.go index 476cdfa..0fc9c05 100644 --- a/promise_test.go +++ b/promise_test.go @@ -361,3 +361,105 @@ func BenchmarkNewWithPool(b *testing.B) { }) } } + +func TestPromise_Race(t *testing.T) { + t.Run("RaceWithFastestResolve", func(t *testing.T) { + ctx := context.Background() + p1 := New(func(resolve func(string), reject func(error)) { + time.Sleep(100 * time.Millisecond) + resolve("slow") + }) + p2 := New(func(resolve func(string), reject func(error)) { + resolve("fast") + }) + p3 := New(func(resolve func(string), reject func(error)) { + time.Sleep(50 * time.Millisecond) + resolve("medium") + }) + + racePromise := Race(ctx, p1, p2, p3) + result, err := racePromise.Await(ctx) + require.NoError(t, err) + require.Equal(t, "fast", result) + }) + + t.Run("RaceWithFastestReject", func(t *testing.T) { + ctx := context.Background() + p1 := New(func(resolve func(string), reject func(error)) { + time.Sleep(100 * time.Millisecond) + resolve("slow") + }) + p2 := New(func(resolve func(string), reject func(error)) { + reject(errors.New("fast error")) + }) + p3 := New(func(resolve func(string), reject func(error)) { + time.Sleep(50 * time.Millisecond) + resolve("medium") + }) + + racePromise := Race(ctx, p1, p2, p3) + _, err := racePromise.Await(ctx) + require.Error(t, err) + require.Equal(t, "fast error", err.Error()) + }) +} + +func TestPromise_Finally(t *testing.T) { + t.Run("FinallyAfterResolve", func(t *testing.T) { + ctx := context.Background() + finallyExecuted := false + p := New(func(resolve func(string), reject func(error)) { + resolve("success") + }) + + finalPromise := Finally(p, ctx, func() { + finallyExecuted = true + }) + + result, err := finalPromise.Await(ctx) + require.NoError(t, err) + require.Equal(t, "success", result) + require.True(t, finallyExecuted) + }) + + t.Run("FinallyAfterReject", func(t *testing.T) { + ctx := context.Background() + finallyExecuted := false + p := New(func(resolve func(string), reject func(error)) { + reject(errors.New("error")) + }) + + finalPromise := Finally(p, ctx, func() { + finallyExecuted = true + }) + + _, err := finalPromise.Await(ctx) + require.Error(t, err) + require.True(t, finallyExecuted) + }) +} + +func TestPromise_Timeout(t *testing.T) { + t.Run("TimeoutBeforeResolve", func(t *testing.T) { + p := New(func(resolve func(string), reject func(error)) { + time.Sleep(200 * time.Millisecond) + resolve("too late") + }) + + timeoutPromise := Timeout(p, 100*time.Millisecond) + _, err := timeoutPromise.Await(context.Background()) + require.Error(t, err) + }) + + t.Run("ResolveBeforeTimeout", func(t *testing.T) { + p := New(func(resolve func(string), reject func(error)) { + time.Sleep(50 * time.Millisecond) + resolve("on time") + }) + + timeoutPromise := Timeout(p, 100*time.Millisecond) + result, err := timeoutPromise.Await(context.Background()) + require.NoError(t, err) + require.Equal(t, "on time", result) + }) +}