diff --git a/servercfg/serverconf.go b/servercfg/serverconf.go index a7347aeb2..c95870e69 100644 --- a/servercfg/serverconf.go +++ b/servercfg/serverconf.go @@ -479,17 +479,36 @@ func GetPublicIP() (string, error) { iplist = append([]string{publicIpService}, iplist...) } + // regular timeout + timeoutTarget := time.Duration(10) * time.Second + + // our unit test sets a bad ip service to test the timeout and fallback mechanics + // if that env variable is set, lower the timeout so the test doesn't take forever + testCheck := os.Getenv("NETMAKER_TEST_BAD_IP_SERVICE") + if testCheck != "" { + // we also don't really need to be checking all the services in the list after the planned timeout + // we just need to make sure that it falls back to the following service, which will be a mock + iplist = iplist[:1] + iplist = append(iplist, os.Getenv("NETMAKER_TEST_IP_SERVICE")) + } + for _, ipserver := range iplist { + if testCheck != "" && strings.EqualFold(testCheck, ipserver) { + timeoutTarget = time.Duration(2) * time.Second + } + client := &http.Client{ - Timeout: time.Second * 10, + Timeout: timeoutTarget, } - resp, err := client.Get(ipserver) + var resp *http.Response + resp, err = client.Get(ipserver) if err != nil { continue } defer resp.Body.Close() if resp.StatusCode == http.StatusOK { - bodyBytes, err := io.ReadAll(resp.Body) + var bodyBytes []byte + bodyBytes, err = io.ReadAll(resp.Body) if err != nil { continue } @@ -619,7 +638,7 @@ func GetEmqxRestEndpoint() string { // IsBasicAuthEnabled - checks if basic auth has been configured to be turned off func IsBasicAuthEnabled() bool { - var enabled = true //default + var enabled = true // default if os.Getenv("BASIC_AUTH") != "" { enabled = os.Getenv("BASIC_AUTH") == "yes" } else if config.Config.Server.BasicAuth != "" { @@ -648,7 +667,7 @@ func GetNetmakerTenantID() string { // GetStunPort - Get the port to run the stun server on func GetStunPort() int { - port := 3478 //default + port := 3478 // default if os.Getenv("STUN_PORT") != "" { portInt, err := strconv.Atoi(os.Getenv("STUN_PORT")) if err == nil { @@ -662,7 +681,7 @@ func GetStunPort() int { // GetTurnPort - Get the port to run the turn server on func GetTurnPort() int { - port := 3479 //default + port := 3479 // default if os.Getenv("TURN_PORT") != "" { portInt, err := strconv.Atoi(os.Getenv("TURN_PORT")) if err == nil { diff --git a/servercfg/serverconf_test.go b/servercfg/serverconf_test.go new file mode 100644 index 000000000..2c8f69075 --- /dev/null +++ b/servercfg/serverconf_test.go @@ -0,0 +1,116 @@ +package servercfg + +import ( + "context" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + "github.com/gravitl/netmaker/config" +) + +func TestGetPublicIP(t *testing.T) { + var testIP = "55.55.55.55" + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if _, err := w.Write([]byte(testIP)); err != nil { + t.Errorf("expected no error, got %v", err) + } + })) + defer server.Close() + if err := os.Setenv("NETMAKER_TEST_IP_SERVICE", server.URL); err != nil { + t.Logf("WARNING: could not set NETMAKER_TEST_IP_SERVICE env var") + } + + t.Run("Use PUBLIC_IP_SERVICE if set", func(t *testing.T) { + + // set the environment variable + if err := os.Setenv("PUBLIC_IP_SERVICE", server.URL); err != nil { + t.Fatalf("expected no error, got %v", err) + } + defer func() { + _ = os.Unsetenv("PUBLIC_IP_SERVICE") + }() + + var ip string + var err error + if ip, err = GetPublicIP(); err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !strings.EqualFold(ip, testIP) { + t.Errorf("expected IP to be %s, got %s", testIP, ip) + } + }) + + t.Run("Use config.Config.Server.PublicIPService if PUBLIC_IP_SERVICE isn't set", func(t *testing.T) { + config.Config.Server.PublicIPService = server.URL + defer func() { config.Config.Server.PublicIPService = "" }() + + ip, err := GetPublicIP() + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if ip != testIP { + t.Fatalf("expected IP to be %s, got %s", testIP, ip) + } + }) + + t.Run("Handle service timeout", func(t *testing.T) { + if os.Getenv("NETMAKER_TEST_IP_SERVICE") == "" { + t.Skip("NETMAKER_TEST_IP_SERVICE not set") + } + + var badTestIP = "123.45.67.91" + badServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // intentionally delay to simulate a timeout + time.Sleep(3 * time.Second) + _, _ = w.Write([]byte(badTestIP)) + })) + defer badServer.Close() + + // we set this so that we can lower the timeout for the test to not hold up CI + if err := os.Setenv("NETMAKER_TEST_BAD_IP_SERVICE", badServer.URL); err != nil { + // but if we can't set it, we can't test it + t.Skip("failed to set NETMAKER_TEST_BAD_IP_SERVICE, skipping test because we won't timeout") + } + defer func() { + _ = os.Unsetenv("NETMAKER_TEST_BAD_IP_SERVICE") + }() + + // mock the config + oldConfig := config.Config.Server.PublicIPService + config.Config.Server.PublicIPService = badServer.URL + defer func() { config.Config.Server.PublicIPService = oldConfig }() + + res, err := GetPublicIP() + if err != nil { + t.Errorf("GetPublicIP() fallback has failed: %v", err) + } + if strings.EqualFold(res, badTestIP) { + t.Errorf("GetPublicIP() returned the response from the server that should have timed out: %v", res) + } + if !strings.EqualFold(res, testIP) { + t.Errorf("GetPublicIP() did not fallback to the correct IP: %v", res) + } + + t.Run("Assert error is passed down", func(t *testing.T) { + oldConfig = config.Config.Server.PublicIPService + oldEnv := os.Getenv("NETMAKER_TEST_IP_SERVICE") + + // make sure that even the fallback fails + if err = os.Setenv("NETMAKER_TEST_IP_SERVICE", badServer.URL); err != nil { + t.Skipf("could not set NETMAKER_TEST_IP_SERVICE env var") + } + defer func() { _ = os.Setenv("NETMAKER_TEST_IP_SERVICE", oldEnv) }() + + // https://github.com/golang/go/issues/63445 + // if _, err = GetPublicIP(); !errors.Is(err, context.DeadlineExceeded) { + if _, err = GetPublicIP(); err == nil || !strings.Contains(err.Error(), "context deadline exceeded") { + t.Errorf("expected error to be %v, got %v", context.DeadlineExceeded, err) + } + }) + }) + +}