From 7ab69c942f37a5eaac805bed062424cc2f4d5a1a Mon Sep 17 00:00:00 2001 From: artem_danilov Date: Tue, 7 Jan 2025 16:57:58 -0800 Subject: [PATCH] add tests to verify that circuit breaker is set for GetRegion calls Signed-off-by: artem_danilov --- internal/mockstore/mocktikv/pd.go | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/internal/mockstore/mocktikv/pd.go b/internal/mockstore/mocktikv/pd.go index c449ea4fc..5791c624f 100644 --- a/internal/mockstore/mocktikv/pd.go +++ b/internal/mockstore/mocktikv/pd.go @@ -55,7 +55,8 @@ import ( "github.com/tikv/pd/client/clients/tso" "github.com/tikv/pd/client/opt" "github.com/tikv/pd/client/pkg/caller" - sd "github.com/tikv/pd/client/servicediscovery" + "github.com/tikv/pd/client/pkg/circuitbreaker" +sd "github.com/tikv/pd/client/servicediscovery" "go.uber.org/atomic" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -226,6 +227,7 @@ func (m *mockTSFuture) Wait() (int64, int64, error) { } func (c *pdClient) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegionOption) (*router.Region, error) { + enforceCircuitBreakerFor("GetRegion", ctx) region, peer, buckets, downPeers := c.cluster.GetRegionByKey(key) if len(opts) == 0 { buckets = nil @@ -244,6 +246,7 @@ func (c *pdClient) GetRegionFromMember(ctx context.Context, key []byte, memberUR } func (c *pdClient) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetRegionOption) (*router.Region, error) { + enforceCircuitBreakerFor("GetPrevRegion", ctx) region, peer, buckets, downPeers := c.cluster.GetPrevRegionByKey(key) if len(opts) == 0 { buckets = nil @@ -252,6 +255,7 @@ func (c *pdClient) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.Ge } func (c *pdClient) GetRegionByID(ctx context.Context, regionID uint64, opts ...opt.GetRegionOption) (*router.Region, error) { + enforceCircuitBreakerFor("GetRegionByID", ctx) region, peer, buckets, downPeers := c.cluster.GetRegionByID(regionID) return &router.Region{Meta: region, Leader: peer, Buckets: buckets, DownPeers: downPeers}, nil } @@ -465,3 +469,9 @@ func (m *pdClient) LoadResourceGroups(ctx context.Context) ([]*rmpb.ResourceGrou func (m *pdClient) GetServiceDiscovery() sd.ServiceDiscovery { return nil } func (m *pdClient) WithCallerComponent(caller.Component) pd.Client { return m } + +func enforceCircuitBreakerFor(name string, ctx context.Context) { + if circuitbreaker.FromContext(ctx) == nil { + panic(fmt.Errorf("CircuitBreaker must be configured for %s", name)) + } +}