Skip to content

Commit

Permalink
Rueidislock CSC re-register (#669)
Browse files Browse the repository at this point in the history
* fix: always re-register rueidislock csc notifications

Signed-off-by: Rueian <[email protected]>

* feat: shorten rueidislock validity if there is a shorter deadline in the context

Signed-off-by: Rueian <[email protected]>

---------

Signed-off-by: Rueian <[email protected]>
  • Loading branch information
rueian authored Nov 13, 2024
1 parent 83cc4f5 commit 57ef0cb
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 99 deletions.
152 changes: 53 additions & 99 deletions rueidislock/lock.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,18 +151,18 @@ func keyname(prefix, name string, i int32) string {
return sb.String()
}

func (m *locker) acquire(ctx context.Context, key, val string, deadline time.Time, force bool) (err error) {
func (m *locker) acquire(ctx context.Context, key, val string, duration time.Duration, deadline time.Time, force bool) (err error) {
ctx, cancel := context.WithTimeout(ctx, m.timeout)
var resp rueidis.RedisResult
if force {
if m.setpx {
resp = fcqms.Exec(ctx, m.client, []string{key}, []string{val, strconv.FormatInt(m.validity.Milliseconds(), 10)})
resp = fcqms.Exec(ctx, m.client, []string{key}, []string{val, strconv.FormatInt(duration.Milliseconds(), 10)})
} else {
resp = fcqat.Exec(ctx, m.client, []string{key}, []string{val, strconv.FormatInt(deadline.UnixMilli(), 10)})
}
} else {
if m.setpx {
resp = acqms.Exec(ctx, m.client, []string{key}, []string{val, strconv.FormatInt(m.validity.Milliseconds(), 10)})
resp = acqms.Exec(ctx, m.client, []string{key}, []string{val, strconv.FormatInt(duration.Milliseconds(), 10)})
} else {
resp = acqat.Exec(ctx, m.client, []string{key}, []string{val, strconv.FormatInt(deadline.UnixMilli(), 10)})
}
Expand All @@ -184,62 +184,16 @@ func (m *locker) script(ctx context.Context, script *rueidis.Lua, key, val strin
return ErrNotLocked
}

func (m *locker) waitgate(ctx context.Context, name string) (g *gate, err error) {
func (m *locker) getgate(name string) (g *gate) {
m.mu.Lock()
g, ok := m.gates[name]
if !ok {
if m.gates == nil {
m.mu.Unlock()
return nil, ErrLockerClosed
defer m.mu.Unlock()
if m.gates != nil {
if g = m.gates[name]; g == nil {
g = makegate(m.totalcnt)
m.gates[name] = g
}
g = makegate(m.totalcnt)
g.w++
m.gates[name] = g
m.mu.Unlock()
return g, nil
} else {
g.w++
m.mu.Unlock()
}
var timeout <-chan time.Time
if m.nocsc {
timeout = time.After(m.timeout)
}
select {
case <-ctx.Done():
m.removegate(g, name)
return nil, ctx.Err()
case _, ok = <-g.ch:
if ok {
return g, nil
}
return nil, ErrLockerClosed
case <-timeout:
return g, nil
}
}

func (m *locker) trygate(name string) (g *gate) {
m.mu.Lock()
if _, ok := m.gates[name]; !ok && m.gates != nil {
g = makegate(m.totalcnt)
g.w++
m.gates[name] = g
}
m.mu.Unlock()
return g
}

func (m *locker) forcegate(name string) (g *gate) {
m.mu.Lock()
if g = m.gates[name]; g == nil && m.gates != nil {
g = makegate(m.totalcnt)
m.gates[name] = g
}
if g != nil {
g.w++
}
m.mu.Unlock()
return g
}

Expand Down Expand Up @@ -295,8 +249,15 @@ func (m *locker) try(ctx context.Context, cancel context.CancelFunc, name string
var err error

val := random()
deadline := time.Now().Add(m.validity)
cacneltm := time.AfterFunc(m.validity, cancel)
now := time.Now()
duration := m.validity
if dl, ok := ctx.Deadline(); ok {
if dur := dl.Sub(now); dur < duration {
duration = dur
}
}
deadline := now.Add(duration)
cacneltm := time.AfterFunc(duration, cancel)
released := int32(0)
acquired := int32(0)
failures := int32(0)
Expand All @@ -312,19 +273,12 @@ func (m *locker) try(ctx context.Context, cancel context.CancelFunc, name string
deadline = deadline.Add(m.interval)
if err = m.script(ctx, extend, key, val, deadline); err == nil {
timer.Reset(m.interval)
if !m.noloop {
<-csc
}
}
case _, ok := <-csc:
if !ok {
err = ErrLockerClosed
} else {
if err = m.script(ctx, extend, key, val, deadline); err == nil {
if !m.noloop {
<-csc
}
}
err = m.script(ctx, extend, key, val, deadline)
}
}
}
Expand Down Expand Up @@ -360,7 +314,7 @@ func (m *locker) try(ctx context.Context, cancel context.CancelFunc, name string
default:
}
if !errors.Is(err, ErrNotLocked) {
if err = m.acquire(ctx, key, val, deadline, force); force && err == nil {
if err = m.acquire(ctx, key, val, duration, deadline, force); force && err == nil {
m.mu.RLock()
if m.gates != nil {
select {
Expand Down Expand Up @@ -403,53 +357,53 @@ func (m *locker) try(ctx context.Context, cancel context.CancelFunc, name string
return cancel, err
}

func (m *locker) ForceWithContext(ctx context.Context, name string) (context.Context, context.CancelFunc, error) {
var err error
func (m *locker) tryonce(ctx context.Context, name string, force bool) (context.Context, context.CancelFunc, error) {
g := m.getgate(name)
if g == nil {
return nil, nil, ErrLockerClosed
}
ctx, cancel := context.WithCancel(ctx)
if g := m.forcegate(name); g != nil {
if cancel, err = m.try(ctx, cancel, name, g, true); err == nil {
return ctx, cancel, nil
}
cancel, err := m.try(ctx, cancel, name, g, force)
if err != nil {
m.removegate(g, name)
}
cancel()
if err == nil {
err = ErrLockerClosed
cancel()
}
return ctx, cancel, err
}

func (m *locker) ForceWithContext(ctx context.Context, name string) (context.Context, context.CancelFunc, error) {
return m.tryonce(ctx, name, true)
}

func (m *locker) TryWithContext(ctx context.Context, name string) (context.Context, context.CancelFunc, error) {
var err error
ctx, cancel := context.WithCancel(ctx)
if g := m.trygate(name); g != nil {
if cancel, err = m.try(ctx, cancel, name, g, false); err == nil {
return ctx, cancel, nil
}
m.removegate(g, name)
}
cancel()
if err == nil {
err = fmt.Errorf("%w: the lock is held by others or the locker is closed", ErrNotLocked)
}
return ctx, cancel, err
return m.tryonce(ctx, name, false)
}

func (m *locker) WithContext(src context.Context, name string) (context.Context, context.CancelFunc, error) {
for {
g := m.getgate(name)
if g == nil {
return nil, nil, ErrLockerClosed
}
ctx, cancel := context.WithCancel(src)
g, err := m.waitgate(ctx, name)
if g != nil {
if cancel, err := m.try(ctx, cancel, name, g, false); err == nil {
return ctx, cancel, nil
}
m.mu.Lock()
g.w-- // do not delete g from m.gates here.
m.mu.Unlock()
if cancel, err := m.try(ctx, cancel, name, g, false); err == nil {
return ctx, cancel, nil
}
if cancel(); err != nil {
return ctx, cancel, err
cancel()
var timeout <-chan time.Time
if m.nocsc {
timeout = time.After(m.timeout)
}
select {
case <-src.Done():
m.removegate(g, name)
return nil, nil, src.Err()
case <-g.ch:
case <-timeout:
}
m.mu.Lock()
g.w--
m.mu.Unlock()
}
}

Expand Down
82 changes: 82 additions & 0 deletions rueidislock/lock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,55 @@ func TestLocker_WithContext_ExtendByClientSideCaching(t *testing.T) {
})
}

func TestLocker_WithContext_AutoExtendConcurrent(t *testing.T) {
test := func(t *testing.T, noLoop, setpx, nocsc bool) {
locker := newLocker(t, noLoop, setpx, nocsc)
locker.validity = time.Second
locker.interval = time.Second / 2
defer locker.Close()

key := strconv.Itoa(rand.Int())

ctx1, cancel1, err1 := locker.WithContext(context.Background(), key)
if err1 != nil {
t.Fatal(err1)
}
go func() {
for i := 0; i < 4; i++ {
select {
case <-ctx1.Done():
t.Errorf("unexpected context canceled %v", ctx1.Err())
default:
time.Sleep(locker.validity)
}
}
cancel1()
}()
ctx2, cancel2, err2 := locker.WithContext(context.Background(), key)
if err2 != nil {
t.Fatal(err2)
}
if !errors.Is(ctx1.Err(), context.Canceled) {
t.Fatalf("unexpected context canceled %v", ctx1.Err())
}
if ctx2.Err() != nil {
t.Fatalf("unexpected context canceled %v", ctx2.Err())
}
cancel2()
}
for _, nocsc := range []bool{false, true} {
t.Run("Tracking Loop", func(t *testing.T) {
test(t, false, false, nocsc)
})
t.Run("Tracking NoLoop", func(t *testing.T) {
test(t, true, false, nocsc)
})
t.Run("SET PX", func(t *testing.T) {
test(t, true, true, nocsc)
})
}
}

func TestLocker_WithContext_AutoExtend(t *testing.T) {
test := func(t *testing.T, noLoop, setpx, nocsc bool) {
locker := newLocker(t, noLoop, setpx, nocsc)
Expand Down Expand Up @@ -446,6 +495,39 @@ func TestLocker_WithContext_CancelContext(t *testing.T) {
}
}

func TestLocker_WithContext_ShorterTimeoutContext(t *testing.T) {
test := func(t *testing.T, noLoop, setpx, nocsc bool) {
locker := newLocker(t, noLoop, setpx, nocsc)
locker.validity = time.Second * 5
locker.interval = time.Second * 3
defer locker.Close()

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()

ctx, cancel, err := locker.WithContext(ctx, strconv.Itoa(rand.Int()))
if err != nil {
t.Fatal(err)
}
time.Sleep(time.Second * 2)
if !errors.Is(ctx.Err(), context.DeadlineExceeded) {
t.Fatalf("unexpected context canceled %v", ctx.Err())
}
cancel()
}
for _, nocsc := range []bool{false, true} {
t.Run("Tracking Loop", func(t *testing.T) {
test(t, false, false, nocsc)
})
t.Run("Tracking NoLoop", func(t *testing.T) {
test(t, true, false, nocsc)
})
t.Run("SET PX", func(t *testing.T) {
test(t, true, true, nocsc)
})
}
}

func TestLocker_TryWithContext(t *testing.T) {
test := func(t *testing.T, noLoop, setpx, nocsc bool) {
locker := newLocker(t, noLoop, setpx, nocsc)
Expand Down

0 comments on commit 57ef0cb

Please sign in to comment.