diff --git a/promise_test.go b/promise_test.go index 0fc9c05..c261fc7 100644 --- a/promise_test.go +++ b/promise_test.go @@ -4,12 +4,12 @@ import ( "context" "errors" "fmt" + "math/rand" "testing" "time" "github.com/panjf2000/ants/v2" conc "github.com/sourcegraph/conc/pool" - "github.com/stretchr/testify/require" ) @@ -105,6 +105,54 @@ func TestPromise_All(t *testing.T) { require.Nil(t, result) }) + t.Run("AllContainRejectAndPanic", func(t *testing.T) { + ctx := context.Background() + p1 := New(func(resolve func(string), reject func(error)) { + resolve("one") + }) + + p2 := New(func(resolve func(string), reject func(error)) { + panic(errors.New("panic")) + }) + + p3 := New(func(resolve func(string), reject func(error)) { + reject(errors.New("error")) + }) + p := All(ctx, p1, p2, p3) + result, err := p.Await(ctx) + require.Error(t, err) + require.Nil(t, result) + }) + + t.Run("AllContainRejectWithDelay", func(t *testing.T) { + ctx := context.Background() + p1 := New(func(resolve func(string), reject func(error)) { + time.Sleep(200 * time.Millisecond) + resolve("one") + }) + + p2 := New(func(resolve func(string), reject func(error)) { + time.Sleep(300 * time.Millisecond) + reject(errors.New("error2")) + }) + + p3 := New(func(resolve func(string), reject func(error)) { + time.Sleep(50 * time.Millisecond) + reject(errors.New("error3")) + }) + + start := time.Now() + p := All(ctx, p1, p2, p3) + result, err := p.Await(ctx) + elapsed := time.Since(start) + + require.Error(t, err) + require.Nil(t, result) + require.GreaterOrEqual(t, elapsed, 50*time.Millisecond, "Promise rejected too quickly") + fmt.Println("total time", elapsed) + require.Less(t, elapsed, 100*time.Millisecond, "Promise did not reject in expected time") + }) + t.Run("AllWithCanceledContext", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -290,7 +338,6 @@ func TestNewWithPool(t *testing.T) { } func TestCheckAllConcurrent(t *testing.T) { - ctx := context.Background() start := time.Now() @@ -463,3 +510,97 @@ func TestPromise_Timeout(t *testing.T) { require.Equal(t, "on time", result) }) } + +func TestPromise_MultipleInOrder(t *testing.T) { + ctx := context.Background() + start := time.Now() + numPromises := 20 + promises := make([]*Promise[int], numPromises) + + for i := 0; i < numPromises; i++ { + delay := time.Duration(rand.Intn(100)) * time.Millisecond + value := i + promises[i] = New(func(resolve func(int), reject func(error)) { + time.Sleep(delay) + resolve(value) + }) + } + + p := All(ctx, promises...) + results, err := p.Await(ctx) + elapsed := time.Since(start) + require.Less(t, elapsed, 110*time.Millisecond) + require.NoError(t, err) + require.Equal(t, numPromises, len(results)) + + for i, result := range results { + require.Equal(t, i, result, "Results should be in order") + } +} + +func TestPromise_AllHTTPCalls(t *testing.T) { + ctx := context.Background() + + type httpResponse1 struct { + Message string `json:"message"` + RequestId string `json:"requestId"` + } + + type httpResponse2 struct { + Username string `json:"username"` + RequestId string `json:"requestId"` + } + + fakeHttp1 := func() (httpResponse1, error) { + time.Sleep(100 * time.Millisecond) + return httpResponse1{ + Message: "hello world", + RequestId: "requestId 1", + }, nil + } + + fakeHttp2 := func() (httpResponse2, error) { + time.Sleep(200 * time.Millisecond) + return httpResponse2{ + Username: "username", + RequestId: "requestId 2", + }, nil + } + + p1 := New(func(resolve func(any), reject func(error)) { + resp1, err := fakeHttp1() + if err != nil { + reject(err) + } else { + resolve(resp1) + } + }) + + p2 := New(func(resolve func(any), reject func(error)) { + resp2, err := fakeHttp2() + if err != nil { + reject(err) + } else { + resolve(resp2) + } + }) + + start := time.Now() + p := All(ctx, p1, p2) + results, err := p.Await(ctx) + elapsed := time.Since(start) + + require.NoError(t, err, "Promise.All should not return an error") + require.Equal(t, 2, len(results), "Should have results for both calls") + require.Less(t, elapsed, 210*time.Millisecond, "Calls should be concurrent") + + res1, ok := results[0].(httpResponse1) + require.True(t, ok, "First result should be of type httpResponse1") + require.Equal(t, "hello world", res1.Message) + require.Equal(t, "requestId 1", res1.RequestId) + + res2, ok := results[1].(httpResponse2) + require.True(t, ok, "Second result should be of type httpResponse2") + require.Equal(t, "username", res2.Username) + require.Equal(t, "requestId 2", res2.RequestId) +}