Skip to content

Commit

Permalink
Check the provided port in the configuration (#31)
Browse files Browse the repository at this point in the history
* Check port number and print diagnostic message when the app is running

Signed-off-by: Doğukan Teber <[email protected]>

* Write unit tests and update the existing unit test cases

Signed-off-by: Doğukan Teber <[email protected]>

* Update tests

Signed-off-by: Doğukan Teber <[email protected]>

* Change default server port

Signed-off-by: Doğukan Teber <[email protected]>

* Update tests

Signed-off-by: Doğukan Teber <[email protected]>

---------

Signed-off-by: Doğukan Teber <[email protected]>
  • Loading branch information
dogukanteber committed Jun 1, 2023
1 parent 22696db commit 9841cb4
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 10 deletions.
10 changes: 8 additions & 2 deletions gateway/gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ import (
)

func TestNewGateway(t *testing.T) {
srv, err := server.New(server.Config{})
srv, err := server.New(server.Config{
HTTPListenAddr: "localhost",
HTTPListenPort: 1234,
UnAuthorizedHTTPListenAddr: "localhost",
UnAuthorizedHTTPListenPort: 1235,
})
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -44,6 +49,7 @@ func TestNewGateway(t *testing.T) {
}

assert.NotNil(t, gw)
srv.Shutdown()
}

func TestStartGateway(t *testing.T) {
Expand Down Expand Up @@ -316,7 +322,7 @@ func TestStartGateway(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
gw, err := createMockGateway("localhost", 8080, 8081, tc.config)
gw, err := createMockGateway("localhost", 8010, 8011, tc.config)
if tc.expectedErr == nil && err != nil {
t.Fatalf("Unexpected error when creating the gateway: %v\n", err)
}
Expand Down
45 changes: 40 additions & 5 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ import (
)

const (
AUTH = "auth"
UNAUTH = "unauth"
DefaultNetwork = "tcp"
AUTH = "auth"
UNAUTH = "unauth"
DefaultNetwork = "tcp"
DefaultAuthPort = 80
DefaultUnauthPort = 8081
)

type Config struct {
Expand Down Expand Up @@ -58,7 +60,12 @@ type server struct {
}

func initAuthServer(cfg *Config, middlewares []middleware.Interface) (*server, error) {
listenAddr := fmt.Sprintf("%s:%d", cfg.HTTPListenAddr, cfg.HTTPListenPort)
port, err := checkPort(cfg.HTTPListenAddr, cfg.HTTPListenPort, DefaultAuthPort, DefaultNetwork)
if err != nil {
return nil, err
}
cfg.HTTPListenPort = port
listenAddr := fmt.Sprintf("%s:%d", cfg.HTTPListenAddr, port)
httpListener, err := net.Listen(DefaultNetwork, listenAddr)
if err != nil {
return nil, err
Expand Down Expand Up @@ -115,6 +122,11 @@ func initAuthServer(cfg *Config, middlewares []middleware.Interface) (*server, e
}

func initUnAuthServer(cfg *Config, middlewares []middleware.Interface) (*server, error) {
port, err := checkPort(cfg.UnAuthorizedHTTPListenAddr, cfg.UnAuthorizedHTTPListenPort, DefaultUnauthPort, DefaultNetwork)
if err != nil {
return nil, err
}
cfg.UnAuthorizedHTTPListenPort = port
listenAddr := fmt.Sprintf("%s:%d", cfg.UnAuthorizedHTTPListenAddr, cfg.UnAuthorizedHTTPListenPort)
unauthHttpListener, err := net.Listen(DefaultNetwork, listenAddr)
if err != nil {
Expand Down Expand Up @@ -222,7 +234,7 @@ func New(cfg Config) (*Server, error) {
}

func (s *Server) Run() error {
logrus.Infof("the server has started listening on %v", s.authServer.httpServer.Addr)
logrus.Infof("the main server has started listening on %v", s.authServer.httpServer.Addr)
errChan := make(chan error, 1)

go func() {
Expand All @@ -237,6 +249,7 @@ func (s *Server) Run() error {
}
}()

logrus.Infof("the admin server has started listening on %v", s.unAuthServer.httpServer.Addr)
go func() {
err := s.unAuthServer.run()
if err == http.ErrServerClosed {
Expand Down Expand Up @@ -292,3 +305,25 @@ func (s *Server) readyHandler(w http.ResponseWriter, r *http.Request) {
func (s *Server) GetHTTPHandlers() (http.Handler, http.Handler) {
return s.authServer.http, s.unAuthServer.http
}

func checkPortAvailable(addr string, port int, network string) bool {
l, err := net.Listen(network, fmt.Sprintf("%s:%d", addr, port))
if err != nil {
return false
}
l.Close()
return true
}

func checkPort(addr string, port int, defaultPort int, network string) (int, error) {
p := port
if port == 0 {
logrus.Info(fmt.Sprintf("port not specified, trying default port %d", defaultPort))
if checkPortAvailable(addr, defaultPort, network) {
p = defaultPort
} else {
return 0, fmt.Errorf(fmt.Sprintf("port %d is not available, please specify a port", defaultPort))
}
}
return p, nil
}
55 changes: 52 additions & 3 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ func TestNew(t *testing.T) {
config: Config{
HTTPListenAddr: "http://localhost",
HTTPListenPort: 8080,
UnAuthorizedHTTPListenAddr: "localhost",
UnAuthorizedHTTPListenPort: 1111,
ServerGracefulShutdownTimeout: time.Second * 5,
HTTPServerReadTimeout: time.Second * 10,
HTTPServerWriteTimeout: time.Second * 10,
Expand All @@ -35,8 +37,10 @@ func TestNew(t *testing.T) {
{
name: "invalid address for unauth",
config: Config{
UnAuthorizedHTTPListenAddr: "http://localhost",
HTTPListenAddr: "localhost",
HTTPListenPort: 8080,
UnAuthorizedHTTPListenAddr: "http://localhost",
UnAuthorizedHTTPListenPort: 8081,
ServerGracefulShutdownTimeout: time.Second * 5,
HTTPServerReadTimeout: time.Second * 10,
HTTPServerWriteTimeout: time.Second * 10,
Expand All @@ -49,6 +53,8 @@ func TestNew(t *testing.T) {
config: Config{
HTTPListenAddr: "localhost",
HTTPListenPort: 8080,
UnAuthorizedHTTPListenAddr: "localhost",
UnAuthorizedHTTPListenPort: 8081,
ServerGracefulShutdownTimeout: time.Second * 5,
HTTPServerReadTimeout: time.Second * 10,
HTTPServerWriteTimeout: time.Second * 10,
Expand All @@ -61,6 +67,8 @@ func TestNew(t *testing.T) {
config: Config{
HTTPListenAddr: "localhost",
HTTPListenPort: 8080,
UnAuthorizedHTTPListenAddr: "localhost",
UnAuthorizedHTTPListenPort: 8081,
ServerGracefulShutdownTimeout: time.Second * 5,
HTTPServerReadTimeout: time.Second * 10,
HTTPServerWriteTimeout: time.Second * 10,
Expand Down Expand Up @@ -94,6 +102,7 @@ func TestNew(t *testing.T) {
t.Errorf("Expected server address to be %s:%d, but got %s", tc.config.HTTPListenAddr, tc.config.HTTPListenPort, server.authServer.httpServer.Addr)
}
}
server.Shutdown()
})
}
}
Expand All @@ -115,7 +124,6 @@ func TestServer_RegisterTo(t *testing.T) {
s.RegisterTo("/test_auth", testHandler, AUTH)
s.RegisterTo("/test_unauth", testHandler, UNAUTH)

// Test authorized server.
req := httptest.NewRequest(http.MethodGet, "/test_auth", nil)
w := httptest.NewRecorder()

Expand All @@ -126,7 +134,6 @@ func TestServer_RegisterTo(t *testing.T) {
t.Errorf("Expected status code %d for AUTH server, but got %d", http.StatusOK, resp.StatusCode)
}

// Test unauthorized server.
req = httptest.NewRequest(http.MethodGet, "/test_unauth", nil)
w = httptest.NewRecorder()

Expand Down Expand Up @@ -227,6 +234,10 @@ func TestRun(t *testing.T) {

func TestReadyHandler(t *testing.T) {
cfg := Config{
HTTPListenAddr: "localhost",
HTTPListenPort: 1234,
UnAuthorizedHTTPListenAddr: "localhost",
UnAuthorizedHTTPListenPort: 1235,
HTTPServerReadTimeout: 5 * time.Second,
HTTPServerWriteTimeout: 5 * time.Second,
HTTPServerIdleTimeout: 5 * time.Second,
Expand Down Expand Up @@ -278,4 +289,42 @@ func TestReadyHandler(t *testing.T) {
}
})
}
s.Shutdown()
}

func TestCheckPortAvailable(t *testing.T) {
tests := []struct {
name string
listenAddr string
listenPort int
wantAvailable bool
}{
{
name: "port available",
listenAddr: "localhost",
listenPort: 8080,
wantAvailable: true,
},
{
name: "port unavailable",
listenAddr: "localhost",
listenPort: 1234,
wantAvailable: false,
},
}

listener, err := net.Listen(DefaultNetwork, fmt.Sprintf("%s:%d", "localhost", 1234))
if err != nil {
t.Fatalf("Failed to create a listener: %v", err)
}
defer listener.Close()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

available := checkPortAvailable(tt.listenAddr, tt.listenPort, DefaultNetwork)
if available != tt.wantAvailable {
t.Errorf("Expected port %d to be available: %v", tt.listenPort, tt.wantAvailable)
}
})
}
}

0 comments on commit 9841cb4

Please sign in to comment.