From 5b28d4693beb302404ca802c262d0dd80f37ffa7 Mon Sep 17 00:00:00 2001 From: zengl Date: Sun, 29 Sep 2024 14:28:59 +0800 Subject: [PATCH] Server ack message propagates AC tokens to agent. Http knock api also set agent's cookie with tokens --- ac/httpac.go | 7 ++--- ac/msghandler.go | 20 ++++++------- ac/tokenstore.go | 31 ++++++++------------ ac/udpac.go | 2 +- common/nhpmsg.go | 17 +++++------ common/types.go | 8 ++++++ server/constants.go | 5 ++-- server/httpserver.go | 52 ++++++++++++++++++++++++++++------ server/plugins/example/main.go | 33 +++++++++++++++------ server/udpserver.go | 35 ++++++++++++----------- 10 files changed, 133 insertions(+), 77 deletions(-) diff --git a/ac/httpac.go b/ac/httpac.go index 70b7409a..f929ed2e 100644 --- a/ac/httpac.go +++ b/ac/httpac.go @@ -134,9 +134,8 @@ func (hs *HttpAC) IsRunning() bool { func (ha *HttpAC) initRouter() { g := ha.ginEngine - pluginGrp := g.Group("refresh") - // display login page with templates - pluginGrp.GET("/:token", func(ctx *gin.Context) { + refreshGrp := g.Group("refresh") + refreshGrp.GET("/:token", func(ctx *gin.Context) { var err error token := ctx.Param("token") log.Info("get refresh request. aspId: %s, query: %v", token, ctx.Request.URL.RawQuery) @@ -198,7 +197,7 @@ func (ha *HttpAC) HandleHttpRefreshOperations(c *gin.Context, req *common.HttpRe entry.SrcAddrs = append(entry.SrcAddrs, newSrcAddr) } - _, err = ha.ua.HandleAccessControl(entry.AgentUser, entry.SrcAddrs, entry.DstAddrs, entry.OpenTime, nil) + _, err = ha.ua.HandleAccessControl(entry.User, entry.SrcAddrs, entry.DstAddrs, entry.OpenTime, nil) if err != nil { c.String(http.StatusOK, "{\"errMsg\": \"%s\"}", err) return diff --git a/ac/msghandler.go b/ac/msghandler.go index 30e58c62..fd8ad425 100644 --- a/ac/msghandler.go +++ b/ac/msghandler.go @@ -41,7 +41,7 @@ func (a *UdpAC) HandleUdpACOperations(ppd *core.PacketParserData) (err error) { srcAddrs := dopMsg.SourceAddrs dstAddrs := dopMsg.DestinationAddrs openTimeSec := int(dopMsg.OpenTime) - agentUser := &AgentUser{ + agentUser := &common.AgentUser{ UserId: dopMsg.UserId, DeviceId: dopMsg.DeviceId, OrganizationId: dopMsg.OrganizationId, @@ -54,10 +54,10 @@ func (a *UdpAC) HandleUdpACOperations(ppd *core.PacketParserData) (err error) { // generate ac token and save user and access information entry := &AccessEntry{ - AgentUser: agentUser, - SrcAddrs: srcAddrs, - DstAddrs: dstAddrs, - OpenTime: openTimeSec, + User: agentUser, + SrcAddrs: srcAddrs, + DstAddrs: dstAddrs, + OpenTime: openTimeSec, } artMsg.ACToken = a.GenerateAccessToken(entry) @@ -84,7 +84,7 @@ func (a *UdpAC) HandleUdpACOperations(ppd *core.PacketParserData) (err error) { return err } -func (a *UdpAC) HandleAccessControl(au *AgentUser, srcAddrs []*common.NetAddress, dstAddrs []*common.NetAddress, openTimeSec int, artMsgIn *common.ACOpsResultMsg) (artMsg *common.ACOpsResultMsg, err error) { +func (a *UdpAC) HandleAccessControl(au *common.AgentUser, srcAddrs []*common.NetAddress, dstAddrs []*common.NetAddress, openTimeSec int, artMsgIn *common.ACOpsResultMsg) (artMsg *common.ACOpsResultMsg, err error) { if artMsgIn == nil { artMsg = &common.ACOpsResultMsg{} } else { @@ -345,10 +345,10 @@ func (a *UdpAC) HandleAccessControl(au *AgentUser, srcAddrs []*common.NetAddress log.Info("[HandleAccessControl] open temporary udp port on %s", tladdr.String()) tempEntry := &AccessEntry{ - AgentUser: au, - SrcAddrs: srcAddrs, - DstAddrs: dstAddrs, - OpenTime: tempOpenTimeSec, + User: au, + SrcAddrs: srcAddrs, + DstAddrs: dstAddrs, + OpenTime: tempOpenTimeSec, } artMsg.PreAccessAction = &common.PreAccessInfo{ AccessPort: strconv.Itoa(pickedPort), diff --git a/ac/tokenstore.go b/ac/tokenstore.go index 79e61514..2f966218 100644 --- a/ac/tokenstore.go +++ b/ac/tokenstore.go @@ -11,23 +11,16 @@ import ( "github.com/OpenNHP/opennhp/log" ) -type AgentUser struct { - UserId string - DeviceId string - OrganizationId string - AuthServiceId string -} - type AccessEntry struct { - AgentUser *AgentUser + User *common.AgentUser SrcAddrs []*common.NetAddress DstAddrs []*common.NetAddress OpenTime int ExpireTime time.Time } -type TokenAccessMap = map[string]*AccessEntry // access token mapped into user and access information -type TokenStore = map[string]TokenAccessMap // upper layer of tokens, indexed by first two characters +type TokenToAccessMap = map[string]*AccessEntry // access token mapped into user and access information +type TokenStore = map[string]TokenToAccessMap // upper layer of tokens, indexed by first two characters func (a *UdpAC) GenerateAccessToken(entry *AccessEntry) string { var tsBytes [8]byte @@ -35,21 +28,21 @@ func (a *UdpAC) GenerateAccessToken(entry *AccessEntry) string { hash := sm3.New() binary.BigEndian.PutUint64(tsBytes[:], uint64(currTime)) - au := entry.AgentUser + au := entry.User hash.Write([]byte(a.config.ACId + au.UserId + au.DeviceId + au.OrganizationId + au.AuthServiceId)) hash.Write(tsBytes[:]) token := base64.StdEncoding.EncodeToString(hash.Sum(nil)) hash.Reset() - a.TokenStoreMutex.Lock() - defer a.TokenStoreMutex.Unlock() + a.tokenStoreMutex.Lock() + defer a.tokenStoreMutex.Unlock() entry.ExpireTime = time.Now().Add(time.Duration(entry.OpenTime) * time.Second) tokenMap, found := a.tokenStore[token[0:1]] if found { tokenMap[token] = entry } else { - tokenMap := make(TokenAccessMap) + tokenMap := make(TokenToAccessMap) tokenMap[token] = entry a.tokenStore[token[0:1]] = tokenMap } @@ -58,8 +51,8 @@ func (a *UdpAC) GenerateAccessToken(entry *AccessEntry) string { } func (a *UdpAC) VerifyAccessToken(token string) *AccessEntry { - a.TokenStoreMutex.Lock() - defer a.TokenStoreMutex.Unlock() + a.tokenStoreMutex.Lock() + defer a.tokenStoreMutex.Unlock() tokenMap, found := a.tokenStore[token[0:1]] if found { @@ -84,14 +77,14 @@ func (a *UdpAC) tokenStoreRefreshRoutine() { return case <-time.After(TokenStoreRefreshInterval * time.Second): - a.TokenStoreMutex.Lock() - defer a.TokenStoreMutex.Unlock() + a.tokenStoreMutex.Lock() + defer a.tokenStoreMutex.Unlock() now := time.Now() for head, tokenMap := range a.tokenStore { for token, entry := range tokenMap { if now.After(entry.ExpireTime) { - log.Info("[TokenStore] token %s expired", token) + log.Info("[TokenStore] token %s expired, remove", token) delete(tokenMap, token) } } diff --git a/ac/udpac.go b/ac/udpac.go index c5706a15..affbbc83 100644 --- a/ac/udpac.go +++ b/ac/udpac.go @@ -40,7 +40,7 @@ type UdpAC struct { serverPeerMutex sync.Mutex serverPeerMap map[string]*core.UdpPeer // indexed by server's public key - TokenStoreMutex sync.Mutex + tokenStoreMutex sync.Mutex tokenStore TokenStore device *core.Device diff --git a/common/nhpmsg.go b/common/nhpmsg.go index c7a2672a..dd1bbc4b 100644 --- a/common/nhpmsg.go +++ b/common/nhpmsg.go @@ -70,14 +70,15 @@ type PreAccessInfo struct { } type ServerKnockAckMsg struct { - ErrCode string `json:"errCode"` - ErrMsg string `json:"errMsg,omitempty"` - ResourceHost map[string]string `json:"resHost"` - OpenTime uint32 `json:"opnTime"` - AuthProviderToken string `json:"aspToken,omitempty"` // optional for ac backend validation - AgentAddr string `json:"agentAddr"` - PreAccessActions []*PreAccessInfo `json:"preActs,omitempty"` // optional for pre-access - RedirectUrl string `json:"redirectUrl,omitempty"` + ErrCode string `json:"errCode"` + ErrMsg string `json:"errMsg,omitempty"` + ResourceHost map[string]string `json:"resHost"` + OpenTime uint32 `json:"opnTime"` + AuthProviderToken string `json:"aspToken,omitempty"` // optional for ac backend validation + AgentAddr string `json:"agentAddr"` + ACTokens map[string]string `json:"acTokens"` + PreAccessActions map[string]*PreAccessInfo `json:"preActions,omitempty"` // optional for pre-access + RedirectUrl string `json:"redirectUrl,omitempty"` } type AgentListMsg struct { diff --git a/common/types.go b/common/types.go index 144cd4b7..a74f5b35 100644 --- a/common/types.go +++ b/common/types.go @@ -2,6 +2,14 @@ package common import "net/url" +// an object contains represent knocking user information +type AgentUser struct { + UserId string + DeviceId string + OrganizationId string + AuthServiceId string +} + // authsvcprovider and resource type LoginPageContext struct { Title string `json:"title,omitempty"` diff --git a/server/constants.go b/server/constants.go index f1eb6a5d..0b0b1e37 100644 --- a/server/constants.go +++ b/server/constants.go @@ -18,6 +18,7 @@ const ( // knock const ( - DefaultIpOpenTime = 120 // second, align with ipset default timeout - ACOpenCompensationTime = 5 // second + DefaultIpOpenTime = 120 // second, align with ipset default timeout + ACOpenCompensationTime = 5 // second + TokenStoreRefreshInterval = 10 // second ) diff --git a/server/httpserver.go b/server/httpserver.go index 3ca15ec2..9ff63f9a 100644 --- a/server/httpserver.go +++ b/server/httpserver.go @@ -253,6 +253,30 @@ func (hs *HttpServer) initRouter() { } hs.legacyAuthWithAspPlugin(ctx, req) }) + + /* + refreshGrp := g.Group("refresh") + refreshGrp.GET("/:token", func(ctx *gin.Context) { + var err error + token := ctx.Param("token") + log.Info("get refresh request. aspId: %s, query: %v", token, ctx.Request.URL.RawQuery) + + if len(token) == 0 { + err = common.ErrUrlPathInvalid + log.Error("path error: %v", err) + ctx.String(http.StatusOK, "{\"errMsg\": \"path error: %v\"}", err) + return + } + + req := &common.HttpRefreshRequest{ + Token: token, + SrcIp: ctx.Query("srcip"), + } + + hs.handleRefreshResource() + }) + */ + } // corsMiddleware is a middleware function that adds CORS headers to the HTTP response. @@ -316,13 +340,13 @@ func (hs *HttpServer) handleHttpOpenResource(req *common.HttpKnockRequest, res * srcAddr := &common.NetAddress{Ip: srcIp} acDstIpMap := make(map[string][]*common.NetAddress) - for _, info := range res.Resources { - addrs, exist := acDstIpMap[info.ACId] + for resName, info := range res.Resources { + addrs, exist := acDstIpMap[resName] if exist { addrs = append(addrs, info.Addr) - acDstIpMap[info.ACId] = addrs + acDstIpMap[resName] = addrs } else { - acDstIpMap[info.ACId] = []*common.NetAddress{info.Addr} + acDstIpMap[resName] = []*common.NetAddress{info.Addr} } } @@ -330,8 +354,11 @@ func (hs *HttpServer) handleHttpOpenResource(req *common.HttpKnockRequest, res * var acWg sync.WaitGroup var artMsgsMutex sync.Mutex artMsgs := make(map[string]*common.ACOpsResultMsg) + ackMsg.ACTokens = make(map[string]string) + ackMsg.PreAccessActions = make(map[string]*common.PreAccessInfo) - for acId, addrs := range acDstIpMap { + for resName, addrs := range acDstIpMap { + acId := res.Resources[resName].ACId s.acConnectionMapMutex.Lock() acConn, found := s.acConnectionMap[acId] s.acConnectionMapMutex.Unlock() @@ -344,14 +371,16 @@ func (hs *HttpServer) handleHttpOpenResource(req *common.HttpKnockRequest, res * } acWg.Add(1) - go func(acip string, dstAddrs []*common.NetAddress) { + go func(name string, dstAddrs []*common.NetAddress) { defer acWg.Done() artMsg, _ := s.processACOperation(knkMsg, acConn, srcAddr, dstAddrs, res.OpenTime) artMsgsMutex.Lock() - artMsgs[acip] = artMsg + artMsgs[name] = artMsg + ackMsg.ACTokens[name] = artMsg.ACToken + ackMsg.PreAccessActions[name] = artMsg.PreAccessAction artMsgsMutex.Unlock() - }(acId, addrs) + }(resName, addrs) } acWg.Wait() @@ -359,7 +388,7 @@ func (hs *HttpServer) handleHttpOpenResource(req *common.HttpKnockRequest, res * for _, artMsg := range artMsgs { if artMsg.ErrCode != common.ErrSuccess.ErrorCode() { errCount++ - break + continue } } @@ -393,3 +422,8 @@ func (hs *HttpServer) NewHttpServerHelper() *plugins.HttpServerPluginHelper { func (hs *HttpServer) FindPluginHandler(aspId string) plugins.PluginHandler { return hs.udpServer.FindPluginHandler(aspId) } + +func (hs *HttpServer) handleRefreshResource(token string) (err error) { + // to do + return nil +} diff --git a/server/plugins/example/main.go b/server/plugins/example/main.go index e6431ffc..8fba9fbc 100644 --- a/server/plugins/example/main.go +++ b/server/plugins/example/main.go @@ -244,14 +244,31 @@ func authRegular(ctx *gin.Context, req *common.HttpKnockRequest, res *common.Res log.Error("RedirectUrl is not provided.") } else { ackMsg.RedirectUrl = res.RedirectUrl - ctx.SetCookie( - "nhp-token", // Name - "example-nhp-token-GUBdoVXpxt", // Value - -1, // MaxAge - "/", // Path - res.CookieDomain, // Domain - true, // Secure - true) // HttpOnly + } + + // set cookies + singleHost := len(ackMsg.ACTokens) == 1 + for resName, token := range ackMsg.ACTokens { + if singleHost { + ctx.SetCookie( + "nhp-token", // Name + token, // Value + -1, // MaxAge + "/", // Path + res.CookieDomain, // Domain + true, // Secure + true) // HttpOnly + } else { + domain := strings.Split(ackMsg.ResourceHost[resName], ":")[0] + ctx.SetCookie( + "nhp-token"+"/"+resName, // Name + token, // Value + -1, // MaxAge + "/", // Path + domain, // Domain + true, // Secure + true) // HttpOnly + } log.Info("ctx.SetCookie.") } } diff --git a/server/udpserver.go b/server/udpserver.go index e322a352..ccb5e1f3 100644 --- a/server/udpserver.go +++ b/server/udpserver.go @@ -55,6 +55,9 @@ type UdpServer struct { acPeerMapMutex sync.Mutex acPeerMap map[string]*core.UdpPeer // indexed by peer's public key base64 string + tokenStoreMutex sync.Mutex + tokenStore TokenStore + // block address management blockAddrMapMutex sync.Mutex blockAddrMap map[string]*BlockAddr // indexed by remote UDP address, need lock for dynamic change @@ -189,6 +192,7 @@ func (s *UdpServer) Start(dirPath string, logLevel int) (err error) { s.remoteConnectionMap = make(map[string]*UdpConn) s.acConnectionMap = make(map[string]*ACConn) + s.tokenStore = make(TokenStore) s.blockAddrMap = make(map[string]*BlockAddr) s.signals.stop = make(chan struct{}) @@ -199,7 +203,8 @@ func (s *UdpServer) Start(dirPath string, logLevel int) (err error) { s.device.Start() // start server routines - s.wg.Add(4) + s.wg.Add(5) + go s.tokenStoreRefreshRoutine() go s.BlockAddrRefreshRoutine() go s.recvPacketRoutine() go s.sendMessageRoutine() @@ -855,13 +860,13 @@ func (s *UdpServer) handleNhpOpenResource(req *common.NhpAuthRequest, res *commo ackMsg = req.Ack acDstIpMap := make(map[string][]*common.NetAddress) - for _, info := range res.Resources { - addrs, exist := acDstIpMap[info.ACId] + for resName, info := range res.Resources { + addrs, exist := acDstIpMap[resName] if exist { addrs = append(addrs, info.Addr) - acDstIpMap[info.ACId] = addrs + acDstIpMap[resName] = addrs } else { - acDstIpMap[info.ACId] = []*common.NetAddress{info.Addr} + acDstIpMap[resName] = []*common.NetAddress{info.Addr} } } @@ -869,8 +874,11 @@ func (s *UdpServer) handleNhpOpenResource(req *common.NhpAuthRequest, res *commo var acWg sync.WaitGroup var artMsgsMutex sync.Mutex artMsgs := make(map[string]*common.ACOpsResultMsg) + ackMsg.ACTokens = make(map[string]string) + ackMsg.PreAccessActions = make(map[string]*common.PreAccessInfo) - for acId, dstAddrs := range acDstIpMap { + for resName, dstAddrs := range acDstIpMap { + acId := res.Resources[resName].ACId s.acConnectionMapMutex.Lock() acConn, found := s.acConnectionMap[acId] s.acConnectionMapMutex.Unlock() @@ -883,7 +891,7 @@ func (s *UdpServer) handleNhpOpenResource(req *common.NhpAuthRequest, res *commo } acWg.Add(1) - go func(id string, addrs []*common.NetAddress) { + go func(name string, addrs []*common.NetAddress) { defer acWg.Done() openTime := res.OpenTime @@ -892,9 +900,11 @@ func (s *UdpServer) handleNhpOpenResource(req *common.NhpAuthRequest, res *commo } artMsg, _ := s.processACOperation(knkMsg, acConn, srcAddr, addrs, openTime) artMsgsMutex.Lock() - artMsgs[id] = artMsg + artMsgs[name] = artMsg + ackMsg.ACTokens[name] = artMsg.ACToken + ackMsg.PreAccessActions[name] = artMsg.PreAccessAction artMsgsMutex.Unlock() - }(acId, dstAddrs) + }(resName, dstAddrs) } acWg.Wait() @@ -914,13 +924,6 @@ func (s *UdpServer) handleNhpOpenResource(req *common.NhpAuthRequest, res *commo return } - ackMsg.PreAccessActions = make([]*common.PreAccessInfo, 0, len(artMsgs)) - for _, artMsg := range artMsgs { - if artMsg.PreAccessAction != nil { - ackMsg.PreAccessActions = append(ackMsg.PreAccessActions, artMsg.PreAccessAction) - } - } - ackMsg.ErrCode = common.ErrSuccess.ErrorCode() ackMsg.ErrMsg = common.ErrSuccess.Error() return ackMsg, nil