diff --git a/gateway/gateway_test.go b/gateway/gateway_test.go index 5aa3406..866c15e 100644 --- a/gateway/gateway_test.go +++ b/gateway/gateway_test.go @@ -14,7 +14,12 @@ import ( ) func TestNewGateway(t *testing.T) { - srv, err := server.New(server.Config{}) + srv, err := server.New(server.Config{ + HTTPListenAddr: "localhost", + HTTPListenPort: 1234, + UnAuthorizedHTTPListenAddr: "localhost", + UnAuthorizedHTTPListenPort: 1235, + }) if err != nil { t.Fatal(err) } @@ -44,6 +49,7 @@ func TestNewGateway(t *testing.T) { } assert.NotNil(t, gw) + srv.Shutdown() } func TestStartGateway(t *testing.T) { @@ -316,7 +322,7 @@ func TestStartGateway(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - gw, err := createMockGateway("localhost", 8080, 8081, tc.config) + gw, err := createMockGateway("localhost", 8010, 8011, tc.config) if tc.expectedErr == nil && err != nil { t.Fatalf("Unexpected error when creating the gateway: %v\n", err) } diff --git a/server/server.go b/server/server.go index 36c75d2..7410c7a 100644 --- a/server/server.go +++ b/server/server.go @@ -18,9 +18,11 @@ import ( ) const ( - AUTH = "auth" - UNAUTH = "unauth" - DefaultNetwork = "tcp" + AUTH = "auth" + UNAUTH = "unauth" + DefaultNetwork = "tcp" + DefaultAuthPort = 80 + DefaultUnauthPort = 8081 ) type Config struct { @@ -58,7 +60,12 @@ type server struct { } func initAuthServer(cfg *Config, middlewares []middleware.Interface) (*server, error) { - listenAddr := fmt.Sprintf("%s:%d", cfg.HTTPListenAddr, cfg.HTTPListenPort) + port, err := checkPort(cfg.HTTPListenAddr, cfg.HTTPListenPort, DefaultAuthPort, DefaultNetwork) + if err != nil { + return nil, err + } + cfg.HTTPListenPort = port + listenAddr := fmt.Sprintf("%s:%d", cfg.HTTPListenAddr, port) httpListener, err := net.Listen(DefaultNetwork, listenAddr) if err != nil { return nil, err @@ -115,6 +122,11 @@ func initAuthServer(cfg *Config, middlewares []middleware.Interface) (*server, e } func initUnAuthServer(cfg *Config, middlewares []middleware.Interface) (*server, error) { + port, err := checkPort(cfg.UnAuthorizedHTTPListenAddr, cfg.UnAuthorizedHTTPListenPort, DefaultUnauthPort, DefaultNetwork) + if err != nil { + return nil, err + } + cfg.UnAuthorizedHTTPListenPort = port listenAddr := fmt.Sprintf("%s:%d", cfg.UnAuthorizedHTTPListenAddr, cfg.UnAuthorizedHTTPListenPort) unauthHttpListener, err := net.Listen(DefaultNetwork, listenAddr) if err != nil { @@ -222,7 +234,7 @@ func New(cfg Config) (*Server, error) { } func (s *Server) Run() error { - logrus.Infof("the server has started listening on %v", s.authServer.httpServer.Addr) + logrus.Infof("the main server has started listening on %v", s.authServer.httpServer.Addr) errChan := make(chan error, 1) go func() { @@ -237,6 +249,7 @@ func (s *Server) Run() error { } }() + logrus.Infof("the admin server has started listening on %v", s.unAuthServer.httpServer.Addr) go func() { err := s.unAuthServer.run() if err == http.ErrServerClosed { @@ -292,3 +305,25 @@ func (s *Server) readyHandler(w http.ResponseWriter, r *http.Request) { func (s *Server) GetHTTPHandlers() (http.Handler, http.Handler) { return s.authServer.http, s.unAuthServer.http } + +func checkPortAvailable(addr string, port int, network string) bool { + l, err := net.Listen(network, fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + return false + } + l.Close() + return true +} + +func checkPort(addr string, port int, defaultPort int, network string) (int, error) { + p := port + if port == 0 { + logrus.Info(fmt.Sprintf("port not specified, trying default port %d", defaultPort)) + if checkPortAvailable(addr, defaultPort, network) { + p = defaultPort + } else { + return 0, fmt.Errorf(fmt.Sprintf("port %d is not available, please specify a port", defaultPort)) + } + } + return p, nil +} diff --git a/server/server_test.go b/server/server_test.go index b0d3547..0ad7570 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -25,6 +25,8 @@ func TestNew(t *testing.T) { config: Config{ HTTPListenAddr: "http://localhost", HTTPListenPort: 8080, + UnAuthorizedHTTPListenAddr: "localhost", + UnAuthorizedHTTPListenPort: 1111, ServerGracefulShutdownTimeout: time.Second * 5, HTTPServerReadTimeout: time.Second * 10, HTTPServerWriteTimeout: time.Second * 10, @@ -35,8 +37,10 @@ func TestNew(t *testing.T) { { name: "invalid address for unauth", config: Config{ - UnAuthorizedHTTPListenAddr: "http://localhost", + HTTPListenAddr: "localhost", HTTPListenPort: 8080, + UnAuthorizedHTTPListenAddr: "http://localhost", + UnAuthorizedHTTPListenPort: 8081, ServerGracefulShutdownTimeout: time.Second * 5, HTTPServerReadTimeout: time.Second * 10, HTTPServerWriteTimeout: time.Second * 10, @@ -49,6 +53,8 @@ func TestNew(t *testing.T) { config: Config{ HTTPListenAddr: "localhost", HTTPListenPort: 8080, + UnAuthorizedHTTPListenAddr: "localhost", + UnAuthorizedHTTPListenPort: 8081, ServerGracefulShutdownTimeout: time.Second * 5, HTTPServerReadTimeout: time.Second * 10, HTTPServerWriteTimeout: time.Second * 10, @@ -61,6 +67,8 @@ func TestNew(t *testing.T) { config: Config{ HTTPListenAddr: "localhost", HTTPListenPort: 8080, + UnAuthorizedHTTPListenAddr: "localhost", + UnAuthorizedHTTPListenPort: 8081, ServerGracefulShutdownTimeout: time.Second * 5, HTTPServerReadTimeout: time.Second * 10, HTTPServerWriteTimeout: time.Second * 10, @@ -94,6 +102,7 @@ func TestNew(t *testing.T) { t.Errorf("Expected server address to be %s:%d, but got %s", tc.config.HTTPListenAddr, tc.config.HTTPListenPort, server.authServer.httpServer.Addr) } } + server.Shutdown() }) } } @@ -115,7 +124,6 @@ func TestServer_RegisterTo(t *testing.T) { s.RegisterTo("/test_auth", testHandler, AUTH) s.RegisterTo("/test_unauth", testHandler, UNAUTH) - // Test authorized server. req := httptest.NewRequest(http.MethodGet, "/test_auth", nil) w := httptest.NewRecorder() @@ -126,7 +134,6 @@ func TestServer_RegisterTo(t *testing.T) { t.Errorf("Expected status code %d for AUTH server, but got %d", http.StatusOK, resp.StatusCode) } - // Test unauthorized server. req = httptest.NewRequest(http.MethodGet, "/test_unauth", nil) w = httptest.NewRecorder() @@ -227,6 +234,10 @@ func TestRun(t *testing.T) { func TestReadyHandler(t *testing.T) { cfg := Config{ + HTTPListenAddr: "localhost", + HTTPListenPort: 1234, + UnAuthorizedHTTPListenAddr: "localhost", + UnAuthorizedHTTPListenPort: 1235, HTTPServerReadTimeout: 5 * time.Second, HTTPServerWriteTimeout: 5 * time.Second, HTTPServerIdleTimeout: 5 * time.Second, @@ -278,4 +289,42 @@ func TestReadyHandler(t *testing.T) { } }) } + s.Shutdown() +} + +func TestCheckPortAvailable(t *testing.T) { + tests := []struct { + name string + listenAddr string + listenPort int + wantAvailable bool + }{ + { + name: "port available", + listenAddr: "localhost", + listenPort: 8080, + wantAvailable: true, + }, + { + name: "port unavailable", + listenAddr: "localhost", + listenPort: 1234, + wantAvailable: false, + }, + } + + listener, err := net.Listen(DefaultNetwork, fmt.Sprintf("%s:%d", "localhost", 1234)) + if err != nil { + t.Fatalf("Failed to create a listener: %v", err) + } + defer listener.Close() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + available := checkPortAvailable(tt.listenAddr, tt.listenPort, DefaultNetwork) + if available != tt.wantAvailable { + t.Errorf("Expected port %d to be available: %v", tt.listenPort, tt.wantAvailable) + } + }) + } }