Skip to content

Commit

Permalink
simplify lastUsage code
Browse files Browse the repository at this point in the history
  • Loading branch information
rekby committed Mar 24, 2024
1 parent 895d516 commit c28ddc7
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 38 deletions.
17 changes: 6 additions & 11 deletions internal/conn/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"sync/atomic"
"time"

"github.com/jonboulle/clockwork"
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
Expand Down Expand Up @@ -327,7 +326,7 @@ func (c *conn) Invoke(
return c.wrapError(err)
}

defer c.lastUsage.Lock()()
defer c.lastUsage.SharedLock()()

ctx, traceID, err := meta.TraceID(ctx)
if err != nil {
Expand Down Expand Up @@ -412,7 +411,7 @@ func (c *conn) NewStream(
return nil, c.wrapError(err)
}

defer c.lastUsage.Lock()()
defer c.lastUsage.SharedLock()()

ctx, traceID, err := meta.TraceID(ctx)
if err != nil {
Expand Down Expand Up @@ -487,15 +486,11 @@ func withOnTransportError(onTransportError func(ctx context.Context, cc Conn, ca
}

func newConn(e endpoint.Endpoint, config Config, opts ...option) *conn {
clock := clockwork.NewRealClock()
c := &conn{
endpoint: e,
config: config,
done: make(chan struct{}),
lastUsage: &lastUsage{
t: clock.Now(),
clock: clock,
},
endpoint: e,
config: config,
done: make(chan struct{}),
lastUsage: newLastUsage(nil),
}
c.state.Store(uint32(Created))
for _, opt := range opts {
Expand Down
6 changes: 3 additions & 3 deletions internal/conn/grpc_client_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func (s *grpcClientStream) CloseSend() (err error) {
onDone(err)
}()

defer s.c.lastUsage.Lock()()
defer s.c.lastUsage.SharedLock()()

err = s.ClientStream.CloseSend()

Expand Down Expand Up @@ -61,7 +61,7 @@ func (s *grpcClientStream) SendMsg(m interface{}) (err error) {
onDone(err)
}()

defer s.c.lastUsage.Lock()()
defer s.c.lastUsage.SharedLock()()

err = s.ClientStream.SendMsg(m)

Expand Down Expand Up @@ -100,7 +100,7 @@ func (s *grpcClientStream) RecvMsg(m interface{}) (err error) {
onDone(err)
}()

defer s.c.lastUsage.Lock()()
defer s.c.lastUsage.SharedLock()()

defer func() {
if err != nil {
Expand Down
34 changes: 18 additions & 16 deletions internal/conn/last_usage.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,40 +6,42 @@ import (
"time"

"github.com/jonboulle/clockwork"

"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync"
)

type lastUsage struct {
locks atomic.Int64
mu xsync.RWMutex
t time.Time
t atomic.Pointer[time.Time]
clock clockwork.Clock
}

func (l *lastUsage) Get() time.Time {
if l.locks.CompareAndSwap(0, 1) {
defer func() {
l.locks.Add(-1)
}()
func newLastUsage(clock clockwork.Clock) *lastUsage {
if clock == nil {
clock = clockwork.NewRealClock()
}
now := clock.Now()
usage := &lastUsage{
clock: clock,
}
usage.t.Store(&now)

l.mu.RLock()
defer l.mu.RUnlock()
return usage
}

return l.t
func (l *lastUsage) Get() time.Time {
if l.locks.Load() == 0 {
return *l.t.Load()
}

return l.clock.Now()
}

func (l *lastUsage) Lock() (releaseFunc func()) {
func (l *lastUsage) SharedLock() (releaseFunc func()) {
l.locks.Add(1)

return sync.OnceFunc(func() {
if l.locks.Add(-1) == 0 {
l.mu.WithLock(func() {
l.t = l.clock.Now()
})
now := l.clock.Now()
l.t.Store(&now)
}
})
}
17 changes: 9 additions & 8 deletions internal/conn/last_usage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ func Test_lastUsage_Lock(t *testing.T) {
start := time.Unix(0, 0)
clock := clockwork.NewFakeClockAt(start)
lu := &lastUsage{
t: start,
clock: clock,
}
lu.t.Store(&start)
t1 := lu.Get()
require.Equal(t, start, t1)
f := lu.Lock()
f := lu.SharedLock()
clock.Advance(time.Hour)
t2 := lu.Get()
require.Equal(t, start.Add(time.Hour), t2)
Expand All @@ -34,19 +34,19 @@ func Test_lastUsage_Lock(t *testing.T) {
start := time.Unix(0, 0)
clock := clockwork.NewFakeClockAt(start)
lu := &lastUsage{
t: start,
clock: clock,
}
lu.t.Store(&start)
t1 := lu.Get()
require.Equal(t, start, t1)
f1 := lu.Lock()
f1 := lu.SharedLock()
clock.Advance(time.Hour)
t2 := lu.Get()
require.Equal(t, start.Add(time.Hour), t2)
f2 := lu.Lock()
f2 := lu.SharedLock()
clock.Advance(time.Hour)
f1()
f3 := lu.Lock()
f3 := lu.SharedLock()
clock.Advance(time.Hour)
t3 := lu.Get()
require.Equal(t, start.Add(3*time.Hour), t3)
Expand All @@ -72,17 +72,18 @@ func Test_lastUsage_Lock(t *testing.T) {
start := time.Unix(0, 0)
clock := clockwork.NewFakeClockAt(start)
lu := &lastUsage{
t: start,
clock: clock,
}
lu.t.Store(&start)

func() {
t1 := lu.Get()
require.Equal(t, start, t1)
clock.Advance(time.Hour)
t2 := lu.Get()
require.Equal(t, start, t2)
clock.Advance(time.Hour)
defer lu.Lock()()
defer lu.SharedLock()()
t3 := lu.Get()
require.Equal(t, start.Add(2*time.Hour), t3)
clock.Advance(time.Hour)
Expand Down

0 comments on commit c28ddc7

Please sign in to comment.