diff --git a/controllers/server.go b/controllers/server.go index 614a8135a..d6a10f547 100644 --- a/controllers/server.go +++ b/controllers/server.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/gorilla/mux" + "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" @@ -15,24 +16,32 @@ import ( func serverHandlers(r *mux.Router) { // r.HandleFunc("/api/server/addnetwork/{network}", securityCheckServer(true, http.HandlerFunc(addNetwork))).Methods(http.MethodPost) - r.HandleFunc("/api/server/health", http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - resp.WriteHeader(http.StatusOK) - resp.Write([]byte("Server is up and running!!")) - })) - r.HandleFunc("/api/server/getconfig", allowUsers(http.HandlerFunc(getConfig))).Methods(http.MethodGet) - r.HandleFunc("/api/server/getserverinfo", Authorize(true, false, "node", http.HandlerFunc(getServerInfo))).Methods(http.MethodGet) + r.HandleFunc( + "/api/server/health", + http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + resp.WriteHeader(http.StatusOK) + resp.Write([]byte("Server is up and running!!")) + }), + ) + r.HandleFunc("/api/server/getconfig", allowUsers(http.HandlerFunc(getConfig))). + Methods(http.MethodGet) + r.HandleFunc("/api/server/getserverinfo", Authorize(true, false, "node", http.HandlerFunc(getServerInfo))). + Methods(http.MethodGet) r.HandleFunc("/api/server/status", http.HandlerFunc(getStatus)).Methods(http.MethodGet) - r.HandleFunc("/api/server/usage", Authorize(true, false, "user", http.HandlerFunc(getUsage))).Methods(http.MethodGet) + r.HandleFunc("/api/server/usage", Authorize(true, false, "user", http.HandlerFunc(getUsage))). + Methods(http.MethodGet) } -func getUsage(w http.ResponseWriter, r *http.Request) { +func getUsage(w http.ResponseWriter, _ *http.Request) { type usage struct { - Hosts int `json:"hosts"` - Clients int `json:"clients"` - Networks int `json:"networks"` - Users int `json:"users"` - Ingresses int `json:"ingresses"` - Egresses int `json:"egresses"` + Hosts int `json:"hosts"` + Clients int `json:"clients"` + Networks int `json:"networks"` + Users int `json:"users"` + Ingresses int `json:"ingresses"` + Egresses int `json:"egresses"` + Relays int `json:"relays"` + InternetGateways int `json:"internet_gateways"` } var serverUsage usage hosts, err := logic.GetAllHosts() @@ -51,6 +60,7 @@ func getUsage(w http.ResponseWriter, r *http.Request) { if err == nil { serverUsage.Networks = len(networks) } + // TODO this part bellow can be optimized to get nodes just once ingresses, err := logic.GetAllIngresses() if err == nil { serverUsage.Ingresses = len(ingresses) @@ -59,12 +69,19 @@ func getUsage(w http.ResponseWriter, r *http.Request) { if err == nil { serverUsage.Egresses = len(egresses) } + relays, err := logic.GetRelays() + if err == nil { + serverUsage.Relays = len(relays) + } + gateways, err := logic.GetInternetGateways() + if err == nil { + serverUsage.InternetGateways = len(gateways) + } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(models.SuccessResponse{ Code: http.StatusOK, Response: serverUsage, }) - } // swagger:route GET /api/server/status server getStatus @@ -105,12 +122,12 @@ func getStatus(w http.ResponseWriter, r *http.Request) { // allowUsers - allow all authenticated (valid) users - only used by getConfig, may be able to remove during refactor func allowUsers(next http.Handler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - var errorResponse = models.ErrorResponse{ + errorResponse := models.ErrorResponse{ Code: http.StatusUnauthorized, Message: logic.Unauthorized_Msg, } bearerToken := r.Header.Get("Authorization") - var tokenSplit = strings.Split(bearerToken, " ") - var authToken = "" + tokenSplit := strings.Split(bearerToken, " ") + authToken := "" if len(tokenSplit) < 2 { logic.ReturnErrorResponse(w, r, errorResponse) return @@ -144,7 +161,7 @@ func getServerInfo(w http.ResponseWriter, r *http.Request) { // get params json.NewEncoder(w).Encode(servercfg.GetServerInfo()) - //w.WriteHeader(http.StatusOK) + // w.WriteHeader(http.StatusOK) } // swagger:route GET /api/server/getconfig server getConfig @@ -170,5 +187,5 @@ func getConfig(w http.ResponseWriter, r *http.Request) { scfg.IsPro = "yes" } json.NewEncoder(w).Encode(scfg) - //w.WriteHeader(http.StatusOK) + // w.WriteHeader(http.StatusOK) } diff --git a/logic/gateway.go b/logic/gateway.go index 2f7234698..342dacb30 100644 --- a/logic/gateway.go +++ b/logic/gateway.go @@ -11,7 +11,27 @@ import ( "github.com/gravitl/netmaker/servercfg" ) -// GetAllIngresses - gets all the hosts that are ingresses +// GetInternetGateways - gets all the nodes that are internet gateways +func GetInternetGateways() ([]models.Node, error) { + nodes, err := GetAllNodes() + if err != nil { + return nil, err + } + igs := make([]models.Node, 0) + for _, node := range nodes { + if !node.IsEgressGateway { + continue + } + for _, ran := range node.EgressGatewayRanges { + if ran == "0.0.0.0/0" { + igs = append(igs, node) + } + } + } + return igs, nil +} + +// GetAllIngresses - gets all the nodes that are ingresses func GetAllIngresses() ([]models.Node, error) { nodes, err := GetAllNodes() if err != nil { @@ -26,7 +46,7 @@ func GetAllIngresses() ([]models.Node, error) { return ingresses, nil } -// GetAllEgresses - gets all the hosts that are egresses +// GetAllEgresses - gets all the nodes that are egresses func GetAllEgresses() ([]models.Node, error) { nodes, err := GetAllNodes() if err != nil { diff --git a/logic/relay.go b/logic/relay.go index 394f3cf28..8c89b3d51 100644 --- a/logic/relay.go +++ b/logic/relay.go @@ -5,6 +5,10 @@ import ( "net" ) +var GetRelays = func() ([]models.Node, error) { + return []models.Node{}, nil +} + var RelayedAllowedIPs = func(peer, node *models.Node) []net.IPNet { return []net.IPNet{} } diff --git a/pro/initialize.go b/pro/initialize.go index 4b8abee36..258859fcf 100644 --- a/pro/initialize.go +++ b/pro/initialize.go @@ -52,6 +52,7 @@ func InitPro() { logic.GetMetrics = proLogic.GetMetrics logic.UpdateMetrics = proLogic.UpdateMetrics logic.DeleteMetrics = proLogic.DeleteMetrics + logic.GetRelays = proLogic.GetRelays logic.GetAllowedIpsForRelayed = proLogic.GetAllowedIpsForRelayed logic.RelayedAllowedIPs = proLogic.RelayedAllowedIPs logic.UpdateRelayed = proLogic.UpdateRelayed diff --git a/pro/license.go b/pro/license.go index e1adeff65..c76505348 100644 --- a/pro/license.go +++ b/pro/license.go @@ -9,17 +9,18 @@ import ( "encoding/json" "errors" "fmt" - "golang.org/x/exp/slog" "io" "net/http" "time" + "golang.org/x/crypto/nacl/box" + "golang.org/x/exp/slog" + "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/netclient/ncutils" "github.com/gravitl/netmaker/servercfg" - "golang.org/x/crypto/nacl/box" ) const ( @@ -28,7 +29,7 @@ const ( type apiServerConf struct { PrivateKey []byte `json:"private_key" binding:"required"` - PublicKey []byte `json:"public_key" binding:"required"` + PublicKey []byte `json:"public_key" binding:"required"` } // AddLicenseHooks - adds the validation and cache clear hooks @@ -112,7 +113,11 @@ func ValidateLicense() (err error) { return err } - respData, err := ncutils.BoxDecrypt(base64decode(licenseResponse.EncryptedLicense), apiPublicKey, tempPrivKey) + respData, err := ncutils.BoxDecrypt( + base64decode(licenseResponse.EncryptedLicense), + apiPublicKey, + tempPrivKey, + ) if err != nil { err = fmt.Errorf("failed to decrypt license: %w", err) return err @@ -132,7 +137,7 @@ func ValidateLicense() (err error) { // as well as secure communication with API // if none present, it generates a new pair func FetchApiServerKeys() (pub *[32]byte, priv *[32]byte, err error) { - var returnData = apiServerConf{} + returnData := apiServerConf{} currentData, err := database.FetchRecord(database.SERVERCONF_TABLE_NAME, db_license_key) if err != nil && !database.IsEmptyRecord(err) { return nil, nil, err @@ -181,7 +186,6 @@ func getLicensePublicKey(licensePubKeyEncoded string) (*[32]byte, error) { } func validateLicenseKey(encryptedData []byte, publicKey *[32]byte) ([]byte, error) { - publicKeyBytes, err := ncutils.ConvertKeyToBytes(publicKey) if err != nil { return nil, err @@ -198,7 +202,11 @@ func validateLicenseKey(encryptedData []byte, publicKey *[32]byte) ([]byte, erro return nil, err } - req, err := http.NewRequest(http.MethodPost, getAccountsHost()+"/api/v1/license/validate", bytes.NewReader(requestBody)) + req, err := http.NewRequest( + http.MethodPost, + getAccountsHost()+"/api/v1/license/validate", + bytes.NewReader(requestBody), + ) if err != nil { return nil, err } @@ -241,7 +249,7 @@ func getAccountsHost() string { } func cacheResponse(response []byte) error { - var lrc = licenseResponseCache{ + lrc := licenseResponseCache{ Body: response, } diff --git a/pro/logic/relays.go b/pro/logic/relays.go index 95614dd6d..db201750c 100644 --- a/pro/logic/relays.go +++ b/pro/logic/relays.go @@ -12,6 +12,21 @@ import ( "net" ) +// GetRelays - gets all the nodes that are relays +func GetRelays() ([]models.Node, error) { + nodes, err := logic.GetAllNodes() + if err != nil { + return nil, err + } + relays := make([]models.Node, 0) + for _, node := range nodes { + if node.IsRelay { + relays = append(relays, node) + } + } + return relays, nil +} + // CreateRelay - creates a relay func CreateRelay(relay models.RelayRequest) ([]models.Node, models.Node, error) { var returnnodes []models.Node @@ -67,7 +82,7 @@ func SetRelayedNodes(setRelayed bool, relay string, relayed []string) []models.N return returnnodes } -//func GetRelayedNodes(relayNode *models.Node) (models.Node, error) { +// func GetRelayedNodes(relayNode *models.Node) (models.Node, error) { // var returnnodes []models.Node // networkNodes, err := GetNetworkNodes(relayNode.Network) // if err != nil { @@ -81,12 +96,12 @@ func SetRelayedNodes(setRelayed bool, relay string, relayed []string) []models.N // } // } // return returnnodes, nil -//} +// } // ValidateRelay - checks if relay is valid func ValidateRelay(relay models.RelayRequest) error { var err error - //isIp := functions.IsIpCIDR(gateway.RangeString) + // isIp := functions.IsIpCIDR(gateway.RangeString) empty := len(relay.RelayedNodes) == 0 if empty { return errors.New("IP Ranges Cannot Be Empty") diff --git a/pro/types.go b/pro/types.go index 97be71943..3e4ef08d1 100644 --- a/pro/types.go +++ b/pro/types.go @@ -54,13 +54,15 @@ type LicenseSecret struct { // Usage - struct for license usage type Usage struct { - Servers int `json:"servers"` - Users int `json:"users"` - Hosts int `json:"hosts"` - Clients int `json:"clients"` - Networks int `json:"networks"` - Ingresses int `json:"ingresses"` - Egresses int `json:"egresses"` + Servers int `json:"servers"` + Users int `json:"users"` + Hosts int `json:"hosts"` + Clients int `json:"clients"` + Networks int `json:"networks"` + Ingresses int `json:"ingresses"` + Egresses int `json:"egresses"` + Relays int `json:"relays"` + InternetGateways int `json:"internet_gateways"` } // Usage.SetDefaults - sets the default values for usage @@ -72,6 +74,8 @@ func (l *Usage) SetDefaults() { l.Networks = 0 l.Ingresses = 0 l.Egresses = 0 + l.Relays = 0 + l.InternetGateways = 0 } // ValidateLicenseRequest - used for request to validate license endpoint diff --git a/pro/util.go b/pro/util.go index f48418b17..7fc2d4cd9 100644 --- a/pro/util.go +++ b/pro/util.go @@ -16,9 +16,7 @@ func base64encode(input []byte) string { // base64decode - base64 decode helper function func base64decode(input string) []byte { - bytes, err := base64.StdEncoding.DecodeString(input) - if err != nil { return nil } @@ -44,6 +42,7 @@ func getCurrentServerUsage() (limits Usage) { if err == nil { limits.Networks = len(networks) } + // TODO this part bellow can be optimized to get nodes just once ingresses, err := logic.GetAllIngresses() if err == nil { limits.Ingresses = len(ingresses) @@ -52,5 +51,13 @@ func getCurrentServerUsage() (limits Usage) { if err == nil { limits.Egresses = len(egresses) } + relays, err := logic.GetRelays() + if err == nil { + limits.Relays = len(relays) + } + gateways, err := logic.GetInternetGateways() + if err == nil { + limits.InternetGateways = len(gateways) + } return }