Skip to content

Commit

Permalink
feat: add new features finally, race, timeout (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
hoanguyenkh authored Sep 22, 2024
1 parent 513e69f commit 16ff1a8
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 88 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,17 @@ import (
func main() {
ctx := context.Background()

p1 := promise4g.New(ctx, func(resolve func(string), reject func(error)) {
p1 := promise4g.New(func(resolve func(string), reject func(error)) {
time.Sleep(100 * time.Millisecond)
resolve("one")
})

p2 := promise4g.New(ctx, func(resolve func(string), reject func(error)) {
p2 := promise4g.New(func(resolve func(string), reject func(error)) {
time.Sleep(200 * time.Millisecond)
resolve("two")
})

p3 := promise4g.New(ctx, func(resolve func(string), reject func(error)) {
p3 := promise4g.New(func(resolve func(string), reject func(error)) {
time.Sleep(300 * time.Millisecond)
resolve("three")
})
Expand Down
184 changes: 99 additions & 85 deletions promise.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,33 @@ import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
)

// Promise represents a computation that will eventually be completed with a value of type T or an error.
type Promise[T any] struct {
value T
err error
ch chan struct{}
value atomic.Value
err atomic.Value
done chan struct{}
once sync.Once
}

func New[T any](
task func(resolve func(T), reject func(error))) *Promise[T] {
// New creates a new Promise with the given task
func New[T any](task func(resolve func(T), reject func(error))) *Promise[T] {
return NewWithPool(task, defaultPool)
}

func NewWithPool[T any](
task func(resolve func(T), reject func(error)),
pool Pool) *Promise[T] {
// NewWithPool creates a new Promise with the given task and pool
func NewWithPool[T any](task func(resolve func(T), reject func(error)), pool Pool) *Promise[T] {
if task == nil {
panic("task must not be nil")
}
if pool == nil {
panic("pool must not be nil")
}
var t T
p := &Promise[T]{
value: t,
err: nil,
ch: make(chan struct{}),
once: sync.Once{},
done: make(chan struct{}),
}
pool.Go(func() {
defer p.handlePanic()
Expand All @@ -42,103 +39,105 @@ func NewWithPool[T any](
return p
}

// Await waits for the Promise to be resolved or rejected
func (p *Promise[T]) Await(ctx context.Context) (T, error) {
select {
case <-ctx.Done():
var t T
return t, ctx.Err()
case <-p.ch:
return p.value, p.err
case <-p.done:
if err := p.err.Load(); err != nil {
var t T
return t, err.(error)
}
return p.value.Load().(T), nil
}
}

func (p *Promise[T]) resolve(value T) {
p.once.Do(func() {
p.value = value
close(p.ch)
p.value.Store(value)
close(p.done)
})
}

func (p *Promise[T]) reject(err error) {
p.once.Do(func() {
p.err = err
close(p.ch)
p.err.Store(err)
close(p.done)
})
}

func (p *Promise[T]) handlePanic() {
err := recover()
if err == nil {
return
}
switch v := err.(type) {
case error:
p.reject(v)
default:
p.reject(fmt.Errorf("%+v", v))
if r := recover(); r != nil {
var err error
switch v := r.(type) {
case error:
err = v
default:
err = fmt.Errorf("%v", v)
}
p.reject(err)
}
}

func All[T any](
ctx context.Context,
promises ...*Promise[T],
) *Promise[[]T] {
// All waits for all promises to be resolved
func All[T any](ctx context.Context, promises ...*Promise[T]) *Promise[[]T] {
return AllWithPool(ctx, defaultPool, promises...)
}

func AllWithPool[T any](
ctx context.Context,
pool Pool,
promises ...*Promise[T],
) *Promise[[]T] {
// AllWithPool waits for all promises to be resolved using the given pool
func AllWithPool[T any](ctx context.Context, pool Pool, promises ...*Promise[T]) *Promise[[]T] {
if len(promises) == 0 {
panic("missing promises")
}

return NewWithPool(func(resolve func([]T), reject func(error)) {
resultsChan := make(chan tuple[T, int], len(promises))
errsChan := make(chan error, len(promises))

for idx, p := range promises {
idx := idx
_ = ThenWithPool(p, ctx, func(data T) (T, error) {
resultsChan <- tuple[T, int]{_1: data, _2: idx}
return data, nil
}, pool)
_ = CatchWithPool(p, ctx, func(err error) error {
errsChan <- err
return err
}, pool)
}

results := make([]T, len(promises))
for idx := 0; idx < len(promises); idx++ {
select {
case result := <-resultsChan:
results[result._2] = result._1
case err := <-errsChan:
reject(err)
return
}
var wg sync.WaitGroup
wg.Add(len(promises))

for i, p := range promises {
i, p := i, p
pool.Go(func() {
defer wg.Done()
result, err := p.Await(ctx)
if err != nil {
reject(err)
return
}
results[i] = result
})
}

wg.Wait()
resolve(results)
}, pool)
}

func Then[A, B any](
p *Promise[A],
ctx context.Context,
resolve func(A) (B, error),
) *Promise[B] {
// Race returns a promise that resolves or rejects as soon as one of the promises resolves or rejects
func Race[T any](ctx context.Context, promises ...*Promise[T]) *Promise[T] {
return NewWithPool(func(resolve func(T), reject func(error)) {
for _, p := range promises {
go func(p *Promise[T]) {
result, err := p.Await(ctx)
if err != nil {
reject(err)
} else {
resolve(result)
}
}(p)
}
}, defaultPool)
}

// Then chains a new Promise to the current one
func Then[A, B any](p *Promise[A], ctx context.Context, resolve func(A) (B, error)) *Promise[B] {
return ThenWithPool(p, ctx, resolve, defaultPool)
}

func ThenWithPool[A, B any](
p *Promise[A],
ctx context.Context,
resolve func(A) (B, error),
pool Pool,
) *Promise[B] {
// ThenWithPool chains a new Promise to the current one using the given pool
func ThenWithPool[A, B any](p *Promise[A], ctx context.Context, resolve func(A) (B, error), pool Pool) *Promise[B] {
return NewWithPool(func(resolveB func(B), reject func(error)) {
result, err := p.Await(ctx)
if err != nil {
Expand All @@ -156,20 +155,13 @@ func ThenWithPool[A, B any](
}, pool)
}

func Catch[T any](
p *Promise[T],
ctx context.Context,
reject func(err error) error,
) *Promise[T] {
// Catch handles errors in the Promise chain
func Catch[T any](p *Promise[T], ctx context.Context, reject func(error) error) *Promise[T] {
return CatchWithPool(p, ctx, reject, defaultPool)
}

func CatchWithPool[T any](
p *Promise[T],
ctx context.Context,
reject func(err error) error,
pool Pool,
) *Promise[T] {
// CatchWithPool handles errors in the Promise chain using the given pool
func CatchWithPool[T any](p *Promise[T], ctx context.Context, reject func(error) error, pool Pool) *Promise[T] {
return NewWithPool(func(resolve func(T), internalReject func(error)) {
result, err := p.Await(ctx)
if err != nil {
Expand All @@ -180,7 +172,29 @@ func CatchWithPool[T any](
}, pool)
}

type tuple[T1, T2 any] struct {
_1 T1
_2 T2
// Finally executes a function regardless of whether the promise is fulfilled or rejected
func Finally[T any](p *Promise[T], ctx context.Context, fn func()) *Promise[T] {
return NewWithPool(func(resolve func(T), reject func(error)) {
result, err := p.Await(ctx)
fn()
if err != nil {
reject(err)
} else {
resolve(result)
}
}, defaultPool)
}

// Timeout returns a new Promise that rejects if the original Promise doesn't resolve within the specified duration
func Timeout[T any](p *Promise[T], d time.Duration) *Promise[T] {
ctx, cancel := context.WithTimeout(context.Background(), d)
return NewWithPool(func(resolve func(T), reject func(error)) {
defer cancel()
result, err := p.Await(ctx)
if err != nil {
reject(err)
} else {
resolve(result)
}
}, defaultPool)
}
102 changes: 102 additions & 0 deletions promise_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,3 +361,105 @@ func BenchmarkNewWithPool(b *testing.B) {
})
}
}

func TestPromise_Race(t *testing.T) {
t.Run("RaceWithFastestResolve", func(t *testing.T) {
ctx := context.Background()
p1 := New(func(resolve func(string), reject func(error)) {
time.Sleep(100 * time.Millisecond)
resolve("slow")
})
p2 := New(func(resolve func(string), reject func(error)) {
resolve("fast")
})
p3 := New(func(resolve func(string), reject func(error)) {
time.Sleep(50 * time.Millisecond)
resolve("medium")
})

racePromise := Race(ctx, p1, p2, p3)
result, err := racePromise.Await(ctx)
require.NoError(t, err)
require.Equal(t, "fast", result)
})

t.Run("RaceWithFastestReject", func(t *testing.T) {
ctx := context.Background()
p1 := New(func(resolve func(string), reject func(error)) {
time.Sleep(100 * time.Millisecond)
resolve("slow")
})
p2 := New(func(resolve func(string), reject func(error)) {
reject(errors.New("fast error"))
})
p3 := New(func(resolve func(string), reject func(error)) {
time.Sleep(50 * time.Millisecond)
resolve("medium")
})

racePromise := Race(ctx, p1, p2, p3)
_, err := racePromise.Await(ctx)
require.Error(t, err)
require.Equal(t, "fast error", err.Error())
})
}

func TestPromise_Finally(t *testing.T) {
t.Run("FinallyAfterResolve", func(t *testing.T) {
ctx := context.Background()
finallyExecuted := false
p := New(func(resolve func(string), reject func(error)) {
resolve("success")
})

finalPromise := Finally(p, ctx, func() {
finallyExecuted = true
})

result, err := finalPromise.Await(ctx)
require.NoError(t, err)
require.Equal(t, "success", result)
require.True(t, finallyExecuted)
})

t.Run("FinallyAfterReject", func(t *testing.T) {
ctx := context.Background()
finallyExecuted := false
p := New(func(resolve func(string), reject func(error)) {
reject(errors.New("error"))
})

finalPromise := Finally(p, ctx, func() {
finallyExecuted = true
})

_, err := finalPromise.Await(ctx)
require.Error(t, err)
require.True(t, finallyExecuted)
})
}

func TestPromise_Timeout(t *testing.T) {
t.Run("TimeoutBeforeResolve", func(t *testing.T) {
p := New(func(resolve func(string), reject func(error)) {
time.Sleep(200 * time.Millisecond)
resolve("too late")
})

timeoutPromise := Timeout(p, 100*time.Millisecond)
_, err := timeoutPromise.Await(context.Background())
require.Error(t, err)
})

t.Run("ResolveBeforeTimeout", func(t *testing.T) {
p := New(func(resolve func(string), reject func(error)) {
time.Sleep(50 * time.Millisecond)
resolve("on time")
})

timeoutPromise := Timeout(p, 100*time.Millisecond)
result, err := timeoutPromise.Await(context.Background())
require.NoError(t, err)
require.Equal(t, "on time", result)
})
}

0 comments on commit 16ff1a8

Please sign in to comment.