diff --git a/client/nginx.go b/client/nginx.go index c6fb9f68..206ecf8f 100644 --- a/client/nginx.go +++ b/client/nginx.go @@ -41,11 +41,14 @@ var ErrUnsupportedVer = errors.New("API version of the client is not supported b // NginxClient lets you access NGINX Plus API. type NginxClient struct { - version int + apiVersion int apiEndpoint string httpClient *http.Client + checkAPI bool } +type Option func(*NginxClient) + type versions []int // UpstreamServer lets you configure HTTP upstreams. @@ -508,35 +511,66 @@ type WorkersHTTP struct { HTTPRequests HTTPRequests `json:"requests"` } -// NewNginxClient creates an NginxClient with the latest supported version. -func NewNginxClient(httpClient *http.Client, apiEndpoint string) (*NginxClient, error) { - return NewNginxClientWithVersion(httpClient, apiEndpoint, APIVersion) +// WithHTTPClient sets the HTTP client to use for accessing the API. +func WithHTTPClient(httpClient *http.Client) Option { + return func(o *NginxClient) { + o.httpClient = httpClient + } } -// NewNginxClientWithVersion creates an NginxClient with the given version of NGINX Plus API. -func NewNginxClientWithVersion(httpClient *http.Client, apiEndpoint string, version int) (*NginxClient, error) { - if !versionSupported(version) { - return nil, fmt.Errorf("API version %v is not supported by the client", version) +// WithAPIVersion sets the API version to use for accessing the API. +func WithAPIVersion(apiVersion int) Option { + return func(o *NginxClient) { + o.apiVersion = apiVersion } - versions, err := getAPIVersions(httpClient, apiEndpoint) - if err != nil { - return nil, fmt.Errorf("error accessing the API: %w", err) +} + +// WithCheckAPI sets the flag to check the API version of the server. +func WithCheckAPI() Option { + return func(o *NginxClient) { + o.checkAPI = true } - found := false - for _, v := range *versions { - if v == version { - found = true - break - } +} + +// NewNginxClient creates a new NginxClient. +func NewNginxClient(apiEndpoint string, opts ...Option) (*NginxClient, error) { + c := &NginxClient{ + httpClient: http.DefaultClient, + apiEndpoint: apiEndpoint, + apiVersion: APIVersion, + checkAPI: false, } - if !found { - return nil, ErrUnsupportedVer + + for _, opt := range opts { + opt(c) } - return &NginxClient{ - apiEndpoint: apiEndpoint, - httpClient: httpClient, - version: version, - }, nil + + if c.httpClient == nil { + return nil, fmt.Errorf("http client is not set") + } + + if !versionSupported(c.apiVersion) { + return nil, fmt.Errorf("API version %v is not supported by the client", c.apiVersion) + } + + if c.checkAPI { + versions, err := getAPIVersions(c.httpClient, apiEndpoint) + if err != nil { + return nil, fmt.Errorf("error accessing the API: %w", err) + } + found := false + for _, v := range *versions { + if v == c.apiVersion { + found = true + break + } + } + if !found { + return nil, fmt.Errorf("API version %v is not supported by the server", c.apiVersion) + } + } + + return c, nil } func versionSupported(n int) bool { @@ -807,7 +841,7 @@ func (client *NginxClient) get(path string, data interface{}) error { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - url := fmt.Sprintf("%v/%v/%v", client.apiEndpoint, client.version, path) + url := fmt.Sprintf("%v/%v/%v", client.apiEndpoint, client.apiVersion, path) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -841,7 +875,7 @@ func (client *NginxClient) post(path string, input interface{}) error { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - url := fmt.Sprintf("%v/%v/%v", client.apiEndpoint, client.version, path) + url := fmt.Sprintf("%v/%v/%v", client.apiEndpoint, client.apiVersion, path) jsonInput, err := json.Marshal(input) if err != nil { @@ -873,7 +907,7 @@ func (client *NginxClient) delete(path string, expectedStatusCode int) error { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - path = fmt.Sprintf("%v/%v/%v/", client.apiEndpoint, client.version, path) + path = fmt.Sprintf("%v/%v/%v/", client.apiEndpoint, client.apiVersion, path) req, err := http.NewRequestWithContext(ctx, http.MethodDelete, path, nil) if err != nil { @@ -898,7 +932,7 @@ func (client *NginxClient) patch(path string, input interface{}, expectedStatusC ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - path = fmt.Sprintf("%v/%v/%v/", client.apiEndpoint, client.version, path) + path = fmt.Sprintf("%v/%v/%v/", client.apiEndpoint, client.apiVersion, path) jsonInput, err := json.Marshal(input) if err != nil { @@ -1359,7 +1393,7 @@ func (client *NginxClient) GetStreamZoneSync() (*StreamZoneSync, error) { // GetLocationZones returns http/location_zones stats. func (client *NginxClient) GetLocationZones() (*LocationZones, error) { var locationZones LocationZones - if client.version < 5 { + if client.apiVersion < 5 { return &locationZones, nil } err := client.get("http/location_zones", &locationZones) @@ -1373,7 +1407,7 @@ func (client *NginxClient) GetLocationZones() (*LocationZones, error) { // GetResolvers returns Resolvers stats. func (client *NginxClient) GetResolvers() (*Resolvers, error) { var resolvers Resolvers - if client.version < 5 { + if client.apiVersion < 5 { return &resolvers, nil } err := client.get("resolvers", &resolvers) @@ -1596,7 +1630,7 @@ func (client *NginxClient) UpdateStreamServer(upstream string, server StreamUpst // Version returns client's current N+ API version. func (client *NginxClient) Version() int { - return client.version + return client.apiVersion } func addPortToServer(server string) string { @@ -1618,7 +1652,7 @@ func addPortToServer(server string) string { // GetHTTPLimitReqs returns http/limit_reqs stats. func (client *NginxClient) GetHTTPLimitReqs() (*HTTPLimitRequests, error) { var limitReqs HTTPLimitRequests - if client.version < 6 { + if client.apiVersion < 6 { return &limitReqs, nil } err := client.get("http/limit_reqs", &limitReqs) @@ -1631,7 +1665,7 @@ func (client *NginxClient) GetHTTPLimitReqs() (*HTTPLimitRequests, error) { // GetHTTPConnectionsLimit returns http/limit_conns stats. func (client *NginxClient) GetHTTPConnectionsLimit() (*HTTPLimitConnections, error) { var limitConns HTTPLimitConnections - if client.version < 6 { + if client.apiVersion < 6 { return &limitConns, nil } err := client.get("http/limit_conns", &limitConns) @@ -1644,7 +1678,7 @@ func (client *NginxClient) GetHTTPConnectionsLimit() (*HTTPLimitConnections, err // GetStreamConnectionsLimit returns stream/limit_conns stats. func (client *NginxClient) GetStreamConnectionsLimit() (*StreamLimitConnections, error) { var limitConns StreamLimitConnections - if client.version < 6 { + if client.apiVersion < 6 { return &limitConns, nil } err := client.get("stream/limit_conns", &limitConns) @@ -1663,7 +1697,7 @@ func (client *NginxClient) GetStreamConnectionsLimit() (*StreamLimitConnections, // GetWorkers returns workers stats. func (client *NginxClient) GetWorkers() ([]*Workers, error) { var workers []*Workers - if client.version < 9 { + if client.apiVersion < 9 { return workers, nil } err := client.get("workers", &workers) diff --git a/client/nginx_test.go b/client/nginx_test.go index b8aa4eb8..0acf9914 100644 --- a/client/nginx_test.go +++ b/client/nginx_test.go @@ -1,6 +1,8 @@ package client import ( + "net/http" + "net/http/httptest" "reflect" "testing" ) @@ -518,3 +520,72 @@ func TestHaveSameParametersForStream(t *testing.T) { } } } + +func TestClientWithCheckAPI(t *testing.T) { + // Create a test server that returns supported API versions + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte(`[4, 5, 6, 7]`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + })) + defer ts.Close() + + // Test creating a new client with a supported API version on the server + client, err := NewNginxClient(ts.URL, WithAPIVersion(7), WithCheckAPI()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if client == nil { + t.Fatalf("client is nil") + } + + // Test creating a new client with an unsupported API version on the server + client, err = NewNginxClient(ts.URL, WithAPIVersion(8), WithCheckAPI()) + if err == nil { + t.Fatalf("expected error, but got nil") + } + if client != nil { + t.Fatalf("expected client to be nil, but got %v", client) + } +} + +func TestClientWithAPIVersion(t *testing.T) { + // Test creating a new client with a supported API version on the client + client, err := NewNginxClient("http://api-url", WithAPIVersion(8)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if client == nil { + t.Fatalf("client is nil") + } + + // Test creating a new client with an unsupported API version on the client + client, err = NewNginxClient("http://api-url", WithAPIVersion(3)) + if err == nil { + t.Fatalf("expected error, but got nil") + } + if client != nil { + t.Fatalf("expected client to be nil, but got %v", client) + } +} + +func TestClientWithHTTPClient(t *testing.T) { + // Test creating a new client passing a custom HTTP client + client, err := NewNginxClient("http://api-url", WithHTTPClient(&http.Client{})) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if client == nil { + t.Fatalf("client is nil") + } + + // Test creating a new client passing a nil HTTP client + client, err = NewNginxClient("http://api-url", WithHTTPClient(nil)) + if err == nil { + t.Fatalf("expected error, but got nil") + } + if client != nil { + t.Fatalf("expected client to be nil, but got %v", client) + } +} diff --git a/tests/client_no_stream_test.go b/tests/client_no_stream_test.go index 6a7221a4..cb29d465 100644 --- a/tests/client_no_stream_test.go +++ b/tests/client_no_stream_test.go @@ -1,7 +1,6 @@ package tests import ( - "net/http" "testing" "github.com/nginxinc/nginx-plus-go-client/client" @@ -13,8 +12,7 @@ import ( // The API returns a special error code that we can use to determine if the API // is misconfigured or of the stream block is missing. func TestStatsNoStream(t *testing.T) { - httpClient := &http.Client{} - c, err := client.NewNginxClient(httpClient, helpers.GetAPIEndpoint()) + c, err := client.NewNginxClient(helpers.GetAPIEndpoint()) if err != nil { t.Fatalf("Error connecting to nginx: %v", err) } diff --git a/tests/client_test.go b/tests/client_test.go index 27fb84e9..f580ad71 100644 --- a/tests/client_test.go +++ b/tests/client_test.go @@ -2,7 +2,6 @@ package tests import ( "net" - "net/http" "reflect" "testing" "time" @@ -34,8 +33,10 @@ var ( ) func TestStreamClient(t *testing.T) { - httpClient := &http.Client{} - c, err := client.NewNginxClient(httpClient, helpers.GetAPIEndpoint()) + c, err := client.NewNginxClient( + helpers.GetAPIEndpoint(), + client.WithCheckAPI(), + ) if err != nil { t.Fatalf("Error when creating a client: %v", err) } @@ -254,8 +255,7 @@ func TestStreamClient(t *testing.T) { } func TestStreamUpstreamServer(t *testing.T) { - httpClient := &http.Client{} - c, err := client.NewNginxClient(httpClient, helpers.GetAPIEndpoint()) + c, err := client.NewNginxClient(helpers.GetAPIEndpoint()) if err != nil { t.Fatalf("Error connecting to nginx: %v", err) } @@ -302,8 +302,7 @@ func TestStreamUpstreamServer(t *testing.T) { } func TestClient(t *testing.T) { - httpClient := &http.Client{} - c, err := client.NewNginxClient(httpClient, helpers.GetAPIEndpoint()) + c, err := client.NewNginxClient(helpers.GetAPIEndpoint()) if err != nil { t.Fatalf("Error when creating a client: %v", err) } @@ -529,8 +528,7 @@ func TestClient(t *testing.T) { } func TestUpstreamServer(t *testing.T) { - httpClient := &http.Client{} - c, err := client.NewNginxClient(httpClient, helpers.GetAPIEndpoint()) + c, err := client.NewNginxClient(helpers.GetAPIEndpoint()) if err != nil { t.Fatalf("Error connecting to nginx: %v", err) } @@ -578,8 +576,7 @@ func TestUpstreamServer(t *testing.T) { } func TestStats(t *testing.T) { - httpClient := &http.Client{} - c, err := client.NewNginxClient(httpClient, helpers.GetAPIEndpoint()) + c, err := client.NewNginxClient(helpers.GetAPIEndpoint()) if err != nil { t.Fatalf("Error connecting to nginx: %v", err) } @@ -720,8 +717,7 @@ func TestStats(t *testing.T) { } func TestUpstreamServerDefaultParameters(t *testing.T) { - httpClient := &http.Client{} - c, err := client.NewNginxClient(httpClient, helpers.GetAPIEndpoint()) + c, err := client.NewNginxClient(helpers.GetAPIEndpoint()) if err != nil { t.Fatalf("Error connecting to nginx: %v", err) } @@ -770,8 +766,7 @@ func TestUpstreamServerDefaultParameters(t *testing.T) { } func TestStreamStats(t *testing.T) { - httpClient := &http.Client{} - c, err := client.NewNginxClient(httpClient, helpers.GetAPIEndpoint()) + c, err := client.NewNginxClient(helpers.GetAPIEndpoint()) if err != nil { t.Fatalf("Error connecting to nginx: %v", err) } @@ -848,8 +843,7 @@ func TestStreamStats(t *testing.T) { } func TestStreamUpstreamServerDefaultParameters(t *testing.T) { - httpClient := &http.Client{} - c, err := client.NewNginxClient(httpClient, helpers.GetAPIEndpoint()) + c, err := client.NewNginxClient(helpers.GetAPIEndpoint()) if err != nil { t.Fatalf("Error connecting to nginx: %v", err) } @@ -897,8 +891,7 @@ func TestStreamUpstreamServerDefaultParameters(t *testing.T) { func TestKeyValue(t *testing.T) { zoneName := "zone_one" - httpClient := &http.Client{} - c, err := client.NewNginxClient(httpClient, helpers.GetAPIEndpoint()) + c, err := client.NewNginxClient(helpers.GetAPIEndpoint()) if err != nil { t.Fatalf("Error connecting to nginx: %v", err) } @@ -995,8 +988,7 @@ func TestKeyValue(t *testing.T) { func TestKeyValueStream(t *testing.T) { zoneName := "zone_one_stream" - httpClient := &http.Client{} - c, err := client.NewNginxClient(httpClient, helpers.GetAPIEndpoint()) + c, err := client.NewNginxClient(helpers.GetAPIEndpoint()) if err != nil { t.Fatalf("Error connecting to nginx: %v", err) } @@ -1092,12 +1084,12 @@ func TestKeyValueStream(t *testing.T) { } func TestStreamZoneSync(t *testing.T) { - c1, err := client.NewNginxClient(&http.Client{}, helpers.GetAPIEndpoint()) + c1, err := client.NewNginxClient(helpers.GetAPIEndpoint()) if err != nil { t.Fatalf("Error connecting to nginx: %v", err) } - c2, err := client.NewNginxClient(&http.Client{}, helpers.GetAPIEndpointOfHelper()) + c2, err := client.NewNginxClient(helpers.GetAPIEndpointOfHelper()) if err != nil { t.Fatalf("Error connecting to nginx: %v", err) } @@ -1218,8 +1210,7 @@ func compareStreamUpstreamServers(x []client.StreamUpstreamServer, y []client.St } func TestUpstreamServerWithDrain(t *testing.T) { - httpClient := &http.Client{} - c, err := client.NewNginxClient(httpClient, helpers.GetAPIEndpoint()) + c, err := client.NewNginxClient(helpers.GetAPIEndpoint()) if err != nil { t.Fatalf("Error connecting to nginx: %v", err) }