diff --git a/pkg/server/server.go b/pkg/server/server.go index f24fe6e..2757018 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -1,6 +1,7 @@ package server import ( + "bufio" "context" "crypto/rand" "crypto/subtle" @@ -234,7 +235,7 @@ func selectUsers(dbPool *pgxpool.Pool, logger *zerolog.Logger, ad authData) ([]u users, err := pgx.CollectRows(rows, pgx.RowToStructByName[user]) if err != nil { logger.Err(err).Msg("unable to CollectRows for users") - return nil, errors.New("unable to get rows for for users") + return nil, errors.New("unable to get rows for users") } return users, nil @@ -927,6 +928,164 @@ func insertServiceVersion(dbPool *pgxpool.Pool, serviceID pgtype.UUID, orgNameOr return serviceVersionResult, nil } +func generateCompleteVcl(sv selectVcl) (string, error) { + var b strings.Builder + + b.WriteString("vcl 4.1;\n") + b.WriteString("import std;\n") + b.WriteString("import proxy;\n") + b.WriteString("\n") + b.WriteString("backend haproxy_https {\n") + b.WriteString(" .path = \"/shared/haproxy_https\"\n") + b.WriteString("}\n") + b.WriteString("backend haproxy_http {\n") + b.WriteString(" .path = \"/shared/haproxy_http\"\n") + b.WriteString("}\n") + b.WriteString("\n") + + for i, origin := range sv.Origins { + b.WriteString(fmt.Sprintf("backend backend_%d {\n", i)) + b.WriteString(fmt.Sprintf(" .host = \"%s\";\n", origin.Host)) + b.WriteString(fmt.Sprintf(" .port = \"%d\";\n", origin.Port)) + if origin.TLS { + b.WriteString(" .via = haproxy_https;\n") + } else { + b.WriteString(" .via = haproxy_http;\n") + } + b.WriteString("}\n") + } + if len(sv.Origins) > 0 { + b.WriteString("\n") + } + + b.WriteString("sub vcl_recv {\n") + if len(sv.Domains) > 0 { + b.WriteString(" if ") + for i, domain := range sv.Domains { + if i > 0 { + b.WriteString(" && ") + } + b.WriteString(fmt.Sprintf("req.http.host != \"%s\"", domain)) + } + b.WriteString(" {\n") + b.WriteString(" return(synth(400,\"Unknown Host header.\"));\n") + b.WriteString(" }\n") + } + + if sv.VclRecvContent != "" { + b.WriteString(" # vcl_recv content from database\n") + scanner := bufio.NewScanner(strings.NewReader(sv.VclRecvContent)) + for scanner.Scan() { + if scanner.Text() != "" { + b.WriteString(" " + scanner.Text() + "\n") + } else { + b.WriteString("\n") + } + } + if err := scanner.Err(); err != nil { + return "", fmt.Errorf("scanning VclRecvContent failed: %w", err) + } + } + b.WriteString("}\n") + + return b.String(), nil +} + +func selectVcls(dbPool *pgxpool.Pool, ad authData) ([]completeVcl, error) { + var rows pgx.Rows + var err error + if ad.superuser { + // Usage of JOIN with subqueries based on + // https://stackoverflow.com/questions/27622398/multiple-array-agg-calls-in-a-single-query + // (including separate version when having WHERE statement based on org). + rows, err = dbPool.Query( + context.Background(), + `SELECT + organizations.id AS org_id, + services.id AS service_id, + service_versions.version, + service_versions.active, + service_vcl_recv.content AS vcl_recv_content, + agg_domains.domains, + agg_origins.origins + FROM + organizations + JOIN services ON organizations.id = services.org_id + JOIN service_versions ON services.id = service_versions.service_id + JOIN service_vcl_recv ON service_versions.id = service_vcl_recv.service_version_id + JOIN ( + SELECT service_version_id, array_agg(domain ORDER BY domain) AS domains + FROM service_domains + GROUP BY service_version_id + ) AS agg_domains ON agg_domains.service_version_id = service_versions.id + JOIN ( + SELECT service_version_id, array_agg((host, port, tls) ORDER BY host, port) AS origins + FROM service_origins + GROUP BY service_version_id + ) AS agg_origins ON agg_origins.service_version_id = service_versions.id + ORDER BY organizations.name`, + ) + if err != nil { + return nil, fmt.Errorf("unable to query for vcls as superuser: %w", err) + } + } else if ad.orgID != nil { + rows, err = dbPool.Query( + context.Background(), + `SELECT + organizations.id AS org_id, + services.id AS service_id, + service_versions.version, + service_versions.active, + service_vcl_recv.content AS vcl_recv_content, + (SELECT + array_agg(domain ORDER BY domain) + FROM service_domains + WHERE service_version_id = service_versions.id + ) AS domains, + (SELECT + array_agg((host, port, tls) ORDER BY host, port) + FROM service_origins + WHERE service_version_id = service_versions.id + ) AS origins + FROM + organizations + JOIN services ON organizations.id = services.org_id + JOIN service_versions ON services.id = service_versions.service_id + JOIN service_vcl_recv ON service_versions.id = service_vcl_recv.service_version_id + WHERE organizations.id=$1 + ORDER BY organizations.name`, + *ad.orgID, + ) + if err != nil { + return nil, fmt.Errorf("unable to query for vcls as normal user: %w", err) + } + } else { + return nil, errForbidden + } + + selectedVcls, err := pgx.CollectRows(rows, pgx.RowToStructByName[selectVcl]) + if err != nil { + return nil, fmt.Errorf("unable to get rows for vcls: %w", err) + } + + var completeVcls []completeVcl + for _, sv := range selectedVcls { + vclContent, err := generateCompleteVcl(sv) + if err != nil { + return nil, fmt.Errorf("unable to generate complete vcl for selected vcl: %w", err) + } + completeVcls = append(completeVcls, completeVcl{ + OrgID: sv.OrgID, + ServiceID: sv.ServiceID, + Active: sv.Active, + Version: sv.Version, + Content: vclContent, + }) + } + + return completeVcls, nil +} + func newChiRouter(logger zerolog.Logger, dbPool *pgxpool.Pool) *chi.Mux { router := chi.NewMux() @@ -1322,6 +1481,31 @@ func setupHumaAPI(router *chi.Mux, dbPool *pgxpool.Pool) error { }, ) + huma.Get(api, "/api/v1/vcls", func(ctx context.Context, _ *struct{}, + ) (*completeVclsOutput, error) { + logger := zlog.Ctx(ctx) + + ad, ok := ctx.Value(authDataKey{}).(authData) + if !ok { + logger.Error().Msg("unable to read auth data from vcls handler") + return nil, errors.New("unable to read auth data from vcls handler") + } + + vcls, err := selectVcls(dbPool, ad) + if err != nil { + if errors.Is(err, errForbidden) { + return nil, huma.Error403Forbidden("not allowed to access resource") + } + logger.Err(err).Msg("unable to query vcls") + return nil, err + } + + resp := &completeVclsOutput{ + Body: vcls, + } + return resp, nil + }) + return nil } @@ -1387,6 +1571,28 @@ type origin struct { TLS bool `json:"tls"` } +type selectVcl struct { + OrgID pgtype.UUID `json:"org_id" doc:"ID of organization"` + ServiceID pgtype.UUID `json:"service_id" doc:"ID of service"` + Active bool `json:"active" example:"true" doc:"If the VCL is active"` + Version int64 `json:"version" example:"1" doc:"Version of the service"` + Domains []string `json:"domains" doc:"The domains used by the VCL"` + Origins []origin `json:"origins" doc:"The origins used by the VCL"` + VclRecvContent string `json:"vcl_recv_content" doc:"The vcl_recv content for the service"` +} + +type completeVcl struct { + OrgID pgtype.UUID `json:"org_id" doc:"ID of organization"` + ServiceID pgtype.UUID `json:"service_id" doc:"ID of service"` + Active bool `json:"active" example:"true" doc:"If the VCL is active"` + Version int64 `json:"version" example:"1" doc:"Version of the service"` + Content string `json:"content" doc:"The complete VCL loaded by varnish"` +} + +type completeVclsOutput struct { + Body []completeVcl +} + func Run(logger zerolog.Logger) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index 36dc8a9..f0f1bbb 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -66,6 +66,8 @@ func prepareServer() (*httptest.Server, *pgxpool.Pool, error) { return nil, nil, errors.New("unable to parse PostgreSQL config string") } + fmt.Println(pgConfig.ConnString()) + dbPool, err := pgxpool.NewWithConfig(ctx, pgConfig) if err != nil { return nil, nil, errors.New("unable to create database pool") @@ -1152,3 +1154,91 @@ func TestPostServiceVersion(t *testing.T) { fmt.Printf("%s\n", jsonData) } } + +func TestGetVcls(t *testing.T) { + ts, dbPool, err := prepareServer() + if dbPool != nil { + defer dbPool.Close() + } + if err != nil { + t.Fatal(err) + } + defer ts.Close() + + tests := []struct { + description string + username string + password string + expectedStatus int + }{ + { + description: "successful superuser request", + username: "admin", + password: "adminpass1", + expectedStatus: http.StatusOK, + }, + { + description: "failed superuser request, bad password", + username: "admin", + password: "badadminpass1", + expectedStatus: http.StatusUnauthorized, + }, + { + description: "successful organization request", + username: "username1", + password: "password1", + expectedStatus: http.StatusOK, + }, + { + description: "failed organization request, bad password", + username: "username1", + password: "badpassword1", + expectedStatus: http.StatusUnauthorized, + }, + } + + for _, test := range tests { + req, err := http.NewRequest("GET", ts.URL+"/api/v1/vcls", nil) + if err != nil { + t.Fatal(err) + } + + req.SetBasicAuth(test.username, test.password) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != test.expectedStatus { + r, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + t.Fatalf("%s: GET vcls unexpected status code: %d (%s)", test.description, resp.StatusCode, string(r)) + } + + jsonData, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + fmt.Printf("%s\n", jsonData) + + if resp.StatusCode == http.StatusOK { + s := []struct { + Content string + }{} + + err = json.Unmarshal(jsonData, &s) + if err != nil { + t.Fatal(err) + } + + for _, content := range s { + fmt.Println(content.Content) + } + } + } +} diff --git a/pkg/server/testdata/migrations/00001_init.sql b/pkg/server/testdata/migrations/00001_init.sql index d7e6fde..47dffb1 100644 --- a/pkg/server/testdata/migrations/00001_init.sql +++ b/pkg/server/testdata/migrations/00001_init.sql @@ -67,10 +67,10 @@ CREATE TABLE service_origins ( host text NOT NULL CONSTRAINT non_empty CHECK(length(host)>0), port integer NOT NULL CONSTRAINT port_range CHECK(port >= 1 AND port <= 65535), tls boolean DEFAULT true NOT NULL, - UNIQUE(service_version_id, host) + UNIQUE(service_version_id, host, port) ); -CREATE TABLE service_vcl_rcv ( +CREATE TABLE service_vcl_recv ( id uuid PRIMARY KEY DEFAULT gen_random_uuid(), ts timestamptz NOT NULL DEFAULT now(), service_version_id uuid NOT NULL REFERENCES service_versions(id), diff --git a/pkg/server/testdata/migrations/00007_add_vcl_rcv.go b/pkg/server/testdata/migrations/00007_add_vcl_recv.go similarity index 84% rename from pkg/server/testdata/migrations/00007_add_vcl_rcv.go rename to pkg/server/testdata/migrations/00007_add_vcl_recv.go index 3341ac6..dd19a68 100644 --- a/pkg/server/testdata/migrations/00007_add_vcl_rcv.go +++ b/pkg/server/testdata/migrations/00007_add_vcl_recv.go @@ -26,7 +26,7 @@ func upAddVclRcv(ctx context.Context, tx *sql.Tx) error { { id: "00000000-0000-0000-0000-000000000028", serviceVersionID: "00000000-0000-0000-0000-000000000015", - file: "testdata/vcl/vcl_rcv/content1.vcl", + file: "testdata/vcl/vcl_recv/content1.vcl", }, } @@ -47,7 +47,7 @@ func upAddVclRcv(ctx context.Context, tx *sql.Tx) error { return err } - _, err = tx.Exec("INSERT INTO service_vcl_rcv (id, service_version_id, content) VALUES($1, $2, $3)", vclID, serviceVersionID, contentBytes) + _, err = tx.Exec("INSERT INTO service_vcl_recv (id, service_version_id, content) VALUES($1, $2, $3)", vclID, serviceVersionID, contentBytes) if err != nil { return err } diff --git a/pkg/server/testdata/migrations/00008_add_service_domains.sql b/pkg/server/testdata/migrations/00008_add_service_domains.sql new file mode 100644 index 0000000..4dba60d --- /dev/null +++ b/pkg/server/testdata/migrations/00008_add_service_domains.sql @@ -0,0 +1,6 @@ +-- +goose up +-- organization1, last version is active +INSERT INTO service_domains (id, service_version_id, domain) VALUES ('00000000-0000-0000-0000-000000000029', '00000000-0000-0000-0000-000000000015', 'www.example.se'); +INSERT INTO service_domains (id, service_version_id, domain) VALUES ('00000000-0000-0000-0000-000000000030', '00000000-0000-0000-0000-000000000015', 'www.example.com'); +-- +goose down +DELETE FROM service_domains; diff --git a/pkg/server/testdata/migrations/00009_add_service_origins.sql b/pkg/server/testdata/migrations/00009_add_service_origins.sql new file mode 100644 index 0000000..22fda88 --- /dev/null +++ b/pkg/server/testdata/migrations/00009_add_service_origins.sql @@ -0,0 +1,6 @@ +-- +goose up +-- organization1, last version is active +INSERT INTO service_origins (id, service_version_id, host, port, tls) VALUES ('00000000-0000-0000-0000-000000000031', '00000000-0000-0000-0000-000000000015', 'srv2.example.com', 80, false); +INSERT INTO service_origins (id, service_version_id, host, port, tls) VALUES ('00000000-0000-0000-0000-000000000032', '00000000-0000-0000-0000-000000000015', 'srv1.example.se', 443, true); +-- +goose down +DELETE FROM service_origins; diff --git a/pkg/server/testdata/vcl/vcl_rcv/content1.vcl b/pkg/server/testdata/vcl/vcl_recv/content1.vcl similarity index 60% rename from pkg/server/testdata/vcl/vcl_rcv/content1.vcl rename to pkg/server/testdata/vcl/vcl_recv/content1.vcl index 07a6862..fd443ae 100644 --- a/pkg/server/testdata/vcl/vcl_rcv/content1.vcl +++ b/pkg/server/testdata/vcl/vcl_recv/content1.vcl @@ -1,7 +1,7 @@ # The usage of the proxy module is possible because haproxy is configured # to set PROXY SSL headers for us. if (proxy.is_ssl()) { - std.syslog(180, "vcl_rcv: this is https"); + std.syslog(180, "vcl_recv: this is https"); } else { - std.syslog(180, "vcl_rcv: this is http"); + std.syslog(180, "vcl_recv: this is http"); }