From c2188f017d58248383b47e2f5ad4921f9a516458 Mon Sep 17 00:00:00 2001 From: Rueian Date: Mon, 11 Nov 2024 23:29:35 -0800 Subject: [PATCH] feat: shorten rueidislock validity if there is a shorter deadline in the context Signed-off-by: Rueian --- rueidislock/lock.go | 28 ++++++++++++++-------------- rueidislock/lock_test.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 14 deletions(-) diff --git a/rueidislock/lock.go b/rueidislock/lock.go index 8b869cf6..597a07c0 100644 --- a/rueidislock/lock.go +++ b/rueidislock/lock.go @@ -151,18 +151,18 @@ func keyname(prefix, name string, i int32) string { return sb.String() } -func (m *locker) acquire(ctx context.Context, key, val string, deadline time.Time, force bool) (err error) { +func (m *locker) acquire(ctx context.Context, key, val string, duration time.Duration, deadline time.Time, force bool) (err error) { ctx, cancel := context.WithTimeout(ctx, m.timeout) var resp rueidis.RedisResult if force { if m.setpx { - resp = fcqms.Exec(ctx, m.client, []string{key}, []string{val, strconv.FormatInt(m.validity.Milliseconds(), 10)}) + resp = fcqms.Exec(ctx, m.client, []string{key}, []string{val, strconv.FormatInt(duration.Milliseconds(), 10)}) } else { resp = fcqat.Exec(ctx, m.client, []string{key}, []string{val, strconv.FormatInt(deadline.UnixMilli(), 10)}) } } else { if m.setpx { - resp = acqms.Exec(ctx, m.client, []string{key}, []string{val, strconv.FormatInt(m.validity.Milliseconds(), 10)}) + resp = acqms.Exec(ctx, m.client, []string{key}, []string{val, strconv.FormatInt(duration.Milliseconds(), 10)}) } else { resp = acqat.Exec(ctx, m.client, []string{key}, []string{val, strconv.FormatInt(deadline.UnixMilli(), 10)}) } @@ -249,8 +249,15 @@ func (m *locker) try(ctx context.Context, cancel context.CancelFunc, name string var err error val := random() - deadline := time.Now().Add(m.validity) - cacneltm := time.AfterFunc(m.validity, cancel) + now := time.Now() + duration := m.validity + if dl, ok := ctx.Deadline(); ok { + if dur := dl.Sub(now); dur < duration { + duration = dur + } + } + deadline := now.Add(duration) + cacneltm := time.AfterFunc(duration, cancel) released := int32(0) acquired := int32(0) failures := int32(0) @@ -266,19 +273,12 @@ func (m *locker) try(ctx context.Context, cancel context.CancelFunc, name string deadline = deadline.Add(m.interval) if err = m.script(ctx, extend, key, val, deadline); err == nil { timer.Reset(m.interval) - if !m.noloop { - <-csc - } } case _, ok := <-csc: if !ok { err = ErrLockerClosed } else { - if err = m.script(ctx, extend, key, val, deadline); err == nil { - if !m.noloop { - <-csc - } - } + err = m.script(ctx, extend, key, val, deadline) } } } @@ -314,7 +314,7 @@ func (m *locker) try(ctx context.Context, cancel context.CancelFunc, name string default: } if !errors.Is(err, ErrNotLocked) { - if err = m.acquire(ctx, key, val, deadline, force); force && err == nil { + if err = m.acquire(ctx, key, val, duration, deadline, force); force && err == nil { m.mu.RLock() if m.gates != nil { select { diff --git a/rueidislock/lock_test.go b/rueidislock/lock_test.go index df473b67..e47152e4 100644 --- a/rueidislock/lock_test.go +++ b/rueidislock/lock_test.go @@ -446,6 +446,39 @@ func TestLocker_WithContext_CancelContext(t *testing.T) { } } +func TestLocker_WithContext_ShorterTimeoutContext(t *testing.T) { + test := func(t *testing.T, noLoop, setpx, nocsc bool) { + locker := newLocker(t, noLoop, setpx, nocsc) + locker.validity = time.Second * 5 + locker.interval = time.Second * 3 + defer locker.Close() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + ctx, cancel, err := locker.WithContext(ctx, strconv.Itoa(rand.Int())) + if err != nil { + t.Fatal(err) + } + time.Sleep(time.Second * 2) + if !errors.Is(ctx.Err(), context.DeadlineExceeded) { + t.Fatalf("unexpected context canceled %v", ctx.Err()) + } + cancel() + } + for _, nocsc := range []bool{false, true} { + t.Run("Tracking Loop", func(t *testing.T) { + test(t, false, false, nocsc) + }) + t.Run("Tracking NoLoop", func(t *testing.T) { + test(t, true, false, nocsc) + }) + t.Run("SET PX", func(t *testing.T) { + test(t, true, true, nocsc) + }) + } +} + func TestLocker_TryWithContext(t *testing.T) { test := func(t *testing.T, noLoop, setpx, nocsc bool) { locker := newLocker(t, noLoop, setpx, nocsc)