From f94ee2f4ae59701ea6d335fee77afe759d0e54e4 Mon Sep 17 00:00:00 2001 From: Guo-shiuan Wang Date: Tue, 20 Feb 2024 22:35:36 +0000 Subject: [PATCH] Implement RemoveNextHop. --- dataplane/saiserver/attrmgr/attrmgr.go | 25 +++++++++++-- dataplane/saiserver/routing.go | 15 ++++++++ dataplane/saiserver/routing_test.go | 49 ++++++++++++++++++++++++++ 3 files changed, 87 insertions(+), 2 deletions(-) diff --git a/dataplane/saiserver/attrmgr/attrmgr.go b/dataplane/saiserver/attrmgr/attrmgr.go index c50c7a1a..3e0600d1 100644 --- a/dataplane/saiserver/attrmgr/attrmgr.go +++ b/dataplane/saiserver/attrmgr/attrmgr.go @@ -50,6 +50,16 @@ type AttrMgr struct { msgEnumToFieldNum map[string]map[int32]int } +func deleteOID(mgr *AttrMgr, oid string) error { + mgr.mu.Lock() + defer mgr.mu.Unlock() + if _, ok := mgr.attrs[oid]; !ok { + return fmt.Errorf("OID not found: %s", oid) + } + delete(mgr.attrs, oid) + return nil +} + // New returns a new AttrMgr. func New() *AttrMgr { mgr := &AttrMgr{ @@ -106,17 +116,28 @@ func (mgr *AttrMgr) Interceptor(ctx context.Context, req any, info *grpc.UnarySe if err != nil { return resp, err } - if strings.Contains(info.FullMethod, "Create") || strings.Contains(info.FullMethod, "Set") { + + switch { + case strings.Contains(info.FullMethod, "Create") || strings.Contains(info.FullMethod, "Set"): id, err := mgr.getID(reqMsg, respMsg) if err != nil { log.Warningf("failed to get id %v", err) return respMsg, nil } mgr.storeAttributes(id, reqMsg) - } else if strings.Contains(info.FullMethod, "Get") && strings.Contains(info.FullMethod, "Attribute") { + case strings.Contains(info.FullMethod, "Get") && strings.Contains(info.FullMethod, "Attribute"): if err := mgr.PopulateAttributes(reqMsg, respMsg); err != nil { return nil, err } + case strings.Contains(info.FullMethod, "Remove"): + id, err := mgr.getID(reqMsg, respMsg) + if err != nil { + log.Warningf("failed to get id %v", err) + return respMsg, nil + } + if err := deleteOID(mgr, id); err != nil { + return nil, err + } } return respMsg, nil } diff --git a/dataplane/saiserver/routing.go b/dataplane/saiserver/routing.go index 8de35338..6881c1e3 100644 --- a/dataplane/saiserver/routing.go +++ b/dataplane/saiserver/routing.go @@ -249,6 +249,21 @@ func (nh *nextHop) CreateNextHop(ctx context.Context, req *saipb.CreateNextHopRe }, nil } +func (nh *nextHop) RemoveNextHop(ctx context.Context, r *saipb.RemoveNextHopRequest) (*saipb.RemoveNextHopResponse, error) { + entry := fwdconfig.EntryDesc(fwdconfig.ExactEntry( + fwdconfig.PacketFieldBytes(fwdpb.PacketFieldNum_PACKET_FIELD_NUM_NEXT_HOP_ID).WithUint64(r.GetOid()))).Build() + nhReq := &fwdpb.TableEntryRemoveRequest{ + ContextId: &fwdpb.ContextId{Id: nh.dataplane.ID()}, + TableId: &fwdpb.TableId{ObjectId: &fwdpb.ObjectId{Id: NHTable}}, + EntryDesc: entry, + } + + if _, err := nh.dataplane.TableEntryRemove(ctx, nhReq); err != nil { + return nil, err + } + return &saipb.RemoveNextHopResponse{}, nil +} + func (nh *nextHop) CreateNextHops(ctx context.Context, r *saipb.CreateNextHopsRequest) (*saipb.CreateNextHopsResponse, error) { resp := &saipb.CreateNextHopsResponse{} for _, req := range r.GetReqs() { diff --git a/dataplane/saiserver/routing_test.go b/dataplane/saiserver/routing_test.go index 12580fe9..a42b8a58 100644 --- a/dataplane/saiserver/routing_test.go +++ b/dataplane/saiserver/routing_test.go @@ -323,6 +323,55 @@ func TestCreateNextHop(t *testing.T) { } } +func TestRemoveNextHop(t *testing.T) { + // Add, remove, and then check. + tests := []struct { + desc string + reqCreate *saipb.CreateNextHopRequest + oid uint64 // specify this if you want an arbitrary OID to remove. + wantErr string + }{ + { + desc: "pass", + reqCreate: &saipb.CreateNextHopRequest{ + Type: saipb.NextHopType_NEXT_HOP_TYPE_IP.Enum(), + RouterInterfaceId: proto.Uint64(10), + Ip: []byte{127, 0, 0, 1}, + }, + }, + { + desc: "fail", + reqCreate: &saipb.CreateNextHopRequest{ + Type: saipb.NextHopType_NEXT_HOP_TYPE_IP.Enum(), + RouterInterfaceId: proto.Uint64(10), + Ip: []byte{127, 0, 0, 1}, + }, + oid: 15, // a non-existing OID. + wantErr: "OID not found: 15", + }, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + dplane := &fakeSwitchDataplane{} + c, _, stopFn := newTestNextHop(t, dplane) + defer stopFn() + + resp, err := c.CreateNextHop(context.TODO(), tt.reqCreate) + if err != nil { + t.Fatalf("Unexpcted error: %v", err) + } + theOid := tt.oid + if theOid == 0 { + theOid = resp.Oid + } + _, gotErr := c.RemoveNextHop(context.TODO(), &saipb.RemoveNextHopRequest{Oid: theOid}) + if diff := errdiff.Check(gotErr, tt.wantErr); diff != "" { + t.Fatalf("RemoveNextHop() unexpected err: %s", diff) + } + }) + } +} + func TestCreateRouteEntry(t *testing.T) { tests := []struct { desc string