diff --git a/.gitignore b/.gitignore index d6a9f88..3acd020 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ build/sandbox-* deploy/lambda/*.zip +cmd/sandbox-api/assets .dev.* +.vscode diff --git a/cmd/sandbox-api/handlers.go b/cmd/sandbox-api/handlers.go index 49bf36d..8be3329 100644 --- a/cmd/sandbox-api/handlers.go +++ b/cmd/sandbox-api/handlers.go @@ -18,19 +18,20 @@ import ( "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/pgxpool" - "github.com/rhpds/sandbox/internal/api/v1" + v1 "github.com/rhpds/sandbox/internal/api/v1" "github.com/rhpds/sandbox/internal/config" "github.com/rhpds/sandbox/internal/log" "github.com/rhpds/sandbox/internal/models" ) type BaseHandler struct { - dbpool *pgxpool.Pool - svc *dynamodb.DynamoDB - doc *openapi3.T - oaRouter oarouters.Router - awsAccountProvider models.AwsAccountProvider - OcpSandboxProvider models.OcpSandboxProvider + dbpool *pgxpool.Pool + svc *dynamodb.DynamoDB + doc *openapi3.T + oaRouter oarouters.Router + awsAccountProvider models.AwsAccountProvider + OcpSandboxProvider models.OcpSandboxProvider + azureSandboxProvider *models.AzureSandboxProvider } type AdminHandler struct { @@ -38,14 +39,23 @@ type AdminHandler struct { tokenAuth *jwtauth.JWTAuth } -func NewBaseHandler(svc *dynamodb.DynamoDB, dbpool *pgxpool.Pool, doc *openapi3.T, oaRouter oarouters.Router, awsAccountProvider models.AwsAccountProvider, OcpSandboxProvider models.OcpSandboxProvider) *BaseHandler { +func NewBaseHandler( + svc *dynamodb.DynamoDB, + dbpool *pgxpool.Pool, + doc *openapi3.T, + oaRouter oarouters.Router, + awsAccountProvider models.AwsAccountProvider, + OcpSandboxProvider models.OcpSandboxProvider, + azureSandboxProvider *models.AzureSandboxProvider, +) *BaseHandler { return &BaseHandler{ - svc: svc, - dbpool: dbpool, - doc: doc, - oaRouter: oaRouter, - awsAccountProvider: awsAccountProvider, - OcpSandboxProvider: OcpSandboxProvider, + svc: svc, + dbpool: dbpool, + doc: doc, + oaRouter: oaRouter, + awsAccountProvider: awsAccountProvider, + OcpSandboxProvider: OcpSandboxProvider, + azureSandboxProvider: azureSandboxProvider, } } @@ -227,6 +237,33 @@ func (h *BaseHandler) CreatePlacementHandler(w http.ResponseWriter, r *http.Requ tocleanup = append(tocleanup, &account) resources = append(resources, account) + case "AzureSandbox": + azureSandbox, err := h.azureSandboxProvider.Request( + placementRequest.ServiceUuid, + placementRequest.Annotations.Merge(request.Annotations), + ) + if err != nil { + // Cleanup previous Azure sandboxes + go func() { + for _, sandbox := range tocleanup { + if err := sandbox.Delete(); err != nil { + log.Logger.Error("Error deleting account", "error", err) + } + } + }() + w.WriteHeader(http.StatusInternalServerError) + render.Render(w, r, &v1.Error{ + Err: err, + HTTPStatusCode: http.StatusInternalServerError, + Message: "Error creating placement in Azure", + }) + log.Logger.Error("CreatePlacementHandler", "error", err) + return + } + + tocleanup = append(tocleanup, &azureSandbox) + resources = append(resources, azureSandbox) + default: w.WriteHeader(http.StatusBadRequest) render.Render(w, r, &v1.Error{ @@ -309,7 +346,6 @@ func (h *BaseHandler) HealthHandler(w http.ResponseWriter, r *http.Request) { // Get All placements func (h *BaseHandler) GetPlacementsHandler(w http.ResponseWriter, r *http.Request) { - placements, err := models.GetAllPlacements(h.dbpool) if err != nil { w.WriteHeader(http.StatusInternalServerError) @@ -350,6 +386,7 @@ func (h *BaseHandler) GetPlacementHandler(w http.ResponseWriter, r *http.Request log.Logger.Error("GetPlacementHandler", "error", err) return } + // TODO: Add Azure provider if err := placement.LoadActiveResourcesWithCreds(h.awsAccountProvider, h.OcpSandboxProvider); err != nil { w.WriteHeader(http.StatusInternalServerError) render.Render(w, r, &v1.Error{ @@ -408,7 +445,7 @@ func (h *BaseHandler) DeletePlacementHandler(w http.ResponseWriter, r *http.Requ } placement.SetStatus("deleting") - go placement.Delete(h.awsAccountProvider, h.OcpSandboxProvider) + go placement.Delete(h.awsAccountProvider, h.OcpSandboxProvider, h.azureSandboxProvider) w.WriteHeader(http.StatusAccepted) render.Render(w, r, &v1.SimpleMessage{ @@ -628,7 +665,6 @@ func (h *BaseHandler) GetStatusPlacementHandler(w http.ResponseWriter, r *http.R } func (h *BaseHandler) GetJWTHandler(w http.ResponseWriter, r *http.Request) { - tokens, err := models.FetchAllTokens(h.dbpool) if err != nil { w.WriteHeader(http.StatusInternalServerError) @@ -644,6 +680,7 @@ func (h *BaseHandler) GetJWTHandler(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) render.Render(w, r, &tokens) } + func (h *AdminHandler) IssueLoginJWTHandler(w http.ResponseWriter, r *http.Request) { request := v1.TokenRequest{} @@ -869,7 +906,6 @@ func (h *BaseHandler) GetStatusRequestHandler(w http.ResponseWriter, r *http.Req // Get the request from the DB job, err := models.GetLifecyclePlacementJobByRequestID(h.dbpool, RequestID) - if err != nil { if err == pgx.ErrNoRows { // No placement request found, try any resource request @@ -915,7 +951,6 @@ func (h *BaseHandler) GetStatusRequestHandler(w http.ResponseWriter, r *http.Req // Get the status of the request status, err := job.GlobalStatus() - if err != nil { w.WriteHeader(http.StatusInternalServerError) render.Render(w, r, &v1.Error{ @@ -1049,7 +1084,6 @@ func (h *BaseHandler) DeleteReservationHandler(w http.ResponseWriter, r *http.Re name := chi.URLParam(r, "name") reservation, err := models.GetReservationByName(h.dbpool, name) - if err != nil { if err == pgx.ErrNoRows { w.WriteHeader(http.StatusNotFound) @@ -1096,7 +1130,6 @@ func (h *BaseHandler) UpdateReservationHandler(w http.ResponseWriter, r *http.Re name := chi.URLParam(r, "name") reservation, err := models.GetReservationByName(h.dbpool, name) - if err != nil { if err == pgx.ErrNoRows { w.WriteHeader(http.StatusNotFound) @@ -1172,7 +1205,6 @@ func (h *BaseHandler) GetReservationHandler(w http.ResponseWriter, r *http.Reque name := chi.URLParam(r, "name") reservation, err := models.GetReservationByName(h.dbpool, name) - if err != nil { if err == pgx.ErrNoRows { w.WriteHeader(http.StatusNotFound) @@ -1206,7 +1238,6 @@ func (h *BaseHandler) GetReservationResourcesHandler(w http.ResponseWriter, r *h name := chi.URLParam(r, "name") reservation, err := models.GetReservationByName(h.dbpool, name) - if err != nil { if err == pgx.ErrNoRows { w.WriteHeader(http.StatusNotFound) @@ -1228,7 +1259,6 @@ func (h *BaseHandler) GetReservationResourcesHandler(w http.ResponseWriter, r *h } accounts, err := h.awsAccountProvider.FetchAllByReservation(reservation.Name) - if err != nil { log.Logger.Error("GET accounts", "error", err) diff --git a/cmd/sandbox-api/main.go b/cmd/sandbox-api/main.go index 481f753..b4fa4e6 100644 --- a/cmd/sandbox-api/main.go +++ b/cmd/sandbox-api/main.go @@ -132,6 +132,15 @@ func main() { // --------------------------------------------------------------------- OcpSandboxProvider := models.NewOcpSandboxProvider(dbPool, vaultSecret) + // --------------------------------------------------------------------- + // Azure + // --------------------------------------------------------------------- + azureSandboxProvider, err := models.NewAzureSandboxProvider(dbPool, vaultSecret) + if err != nil { + log.Logger.Error("Error creating AzureSandboxProvider", "error", err) + os.Exit(1) + } + // --------------------------------------------------------------------- // Setup JWT // --------------------------------------------------------------------- @@ -155,7 +164,15 @@ func main() { accountHandler := NewAccountHandler(awsAccountProvider, OcpSandboxProvider) // Factory for handlers which need connections to both databases - baseHandler := NewBaseHandler(awsAccountProvider.Svc, dbPool, doc, oaRouter, awsAccountProvider, OcpSandboxProvider) + baseHandler := NewBaseHandler( + awsAccountProvider.Svc, + dbPool, + doc, + oaRouter, + awsAccountProvider, + OcpSandboxProvider, + azureSandboxProvider, + ) // Admin handler adds tokenAuth to the baseHandler adminHandler := NewAdminHandler(baseHandler, tokenAuth) diff --git a/db/migrations/013_azure_sandboxes.up.sql b/db/migrations/013_azure_sandboxes.up.sql new file mode 100644 index 0000000..13f9cee --- /dev/null +++ b/db/migrations/013_azure_sandboxes.up.sql @@ -0,0 +1,2 @@ +-- Add AzureSandbox value to the resource_type_enum +ALTER TYPE resource_type_enum ADD VALUE 'AzureSandbox'; \ No newline at end of file diff --git a/docs/api-reference/swagger.yaml b/docs/api-reference/swagger.yaml index ced4ef2..b163113 100644 --- a/docs/api-reference/swagger.yaml +++ b/docs/api-reference/swagger.yaml @@ -2217,6 +2217,7 @@ components: enum: - AwsSandbox - OcpSandbox + - AzureSandbox Reservation: description: Reservation record diff --git a/internal/api/azure/application.go b/internal/api/azure/application.go new file mode 100644 index 0000000..bb56b9c --- /dev/null +++ b/internal/api/azure/application.go @@ -0,0 +1,234 @@ +package azure + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +type application struct { + AppID string + DisplayName string + Password string +} + +// createApplication creates a new application and generate random password. +func (g *graphClient) createApplication(name string) (*application, error) { + err := g.refreshToken() + if err != nil { + return nil, err + } + + restClient := &http.Client{ + Timeout: 10 * time.Second, + } + + requestBody := struct { + DisplayName string `json:"displayName"` + PasswordCredentials []struct { + DisplayName string `json:"displayName"` + } `json:"passwordCredentials"` + }{ + DisplayName: name, + PasswordCredentials: []struct { + DisplayName string `json:"displayName"` + }{ + {DisplayName: "rbac"}, + }, + } + payloadBytes, err := json.Marshal(requestBody) + if err != nil { + return nil, err + } + + req, err := http.NewRequest( + "POST", + "https://graph.microsoft.com/v1.0/applications", + bytes.NewBuffer(payloadBytes), + ) + if err != nil { + return nil, err + } + req.Header.Add("Authorization", "Bearer "+g.token.AccessToken) + req.Header.Add("Content-Type", "application/json") + + response, err := restClient.Do(req) + if err != nil { + return nil, err + } + defer response.Body.Close() + + responseData, err := io.ReadAll(response.Body) + if err != nil { + return nil, err + } + + if response.StatusCode != http.StatusCreated { + // Graph API reference tells nothing about possible errors or the + // error response format. + return nil, + fmt.Errorf("failed to create the application: %s, server response: %s", + name, string(responseData)) + } + + responseBody := struct { + AppID string `json:"appId"` + DisplayName string `json:"displayName"` + PasswordCredentials []struct { + DisplayName string `json:"displayName"` + SecretText string `json:"secretText"` + } `json:"passwordCredentials"` + }{} + err = json.Unmarshal(responseData, &responseBody) + if err != nil { + return nil, err + } + + // Only first password is used, so ignore the rest. This is a simplification + // and may not be correct in all cases. + return &application{ + AppID: responseBody.AppID, + DisplayName: responseBody.DisplayName, + Password: responseBody.PasswordCredentials[0].SecretText, + }, nil +} + +// getApplicationObjectIDs returns the object IDs of the applications with the +// given name. +func (g *graphClient) getApplicationObjectIDs(name string) ([]string, error) { + err := g.refreshToken() + if err != nil { + return nil, err + } + + restClient := &http.Client{ + Timeout: 10 * time.Second, + } + + req, err := http.NewRequest( + "GET", + fmt.Sprintf( + "https://graph.microsoft.com/v1.0/applications"+ + "?$search=\"displayName:%s\"&$count=true&$select=Id", + name, + ), + nil, + ) + if err != nil { + return nil, err + } + req.Header.Add("ConsistencyLevel", "eventual") + req.Header.Add("Authorization", "Bearer "+g.token.AccessToken) + + response, err := restClient.Do(req) + if err != nil { + return nil, err + } + defer response.Body.Close() + + responseData, err := io.ReadAll(response.Body) + if err != nil { + return nil, err + } + + if response.StatusCode != http.StatusOK { + // Graph API reference tells nothing about possible errors or the + // error response format. + return nil, + fmt.Errorf("failed to get the application object IDs, server response: %s", + string(responseData)) + } + + responseBody := struct { + Value []struct { + ID string `json:"id"` + } `json:"value"` + }{} + err = json.Unmarshal(responseData, &responseBody) + if err != nil { + return nil, err + } + + objectIDs := make([]string, 0, len(responseBody.Value)) + for _, app := range responseBody.Value { + objectIDs = append(objectIDs, app.ID) + } + + return objectIDs, nil +} + +// deleteApplication deletes the application with the given object ID. +func (g *graphClient) deleteApplication(objectID string) error { + err := g.refreshToken() + if err != nil { + return err + } + + restClient := &http.Client{ + Timeout: 10 * time.Second, + } + + req, err := http.NewRequest( + "DELETE", + "https://graph.microsoft.com/v1.0/applications/"+objectID, + nil, + ) + if err != nil { + return err + } + req.Header.Add("Authorization", "Bearer "+g.token.AccessToken) + + response, err := restClient.Do(req) + if err != nil { + return err + } + defer response.Body.Close() + + if response.StatusCode != http.StatusNoContent { + // Graph API reference tells nothing about possible errors or the + // error response format. + return fmt.Errorf("failed to delete the application with ID: %s", objectID) + } + + return nil +} + +// permanentDeleteApplication deletes the application with the given object ID +// permanently. This operation cannot be undone. +func (g *graphClient) permanentDeleteApplication(objectID string) error { + err := g.refreshToken() + if err != nil { + return err + } + + restClient := &http.Client{ + Timeout: 10 * time.Second, + } + + req, err := http.NewRequest( + "DELETE", + "https://graph.microsoft.com/v1.0/directory/deletedItems/"+objectID, + nil, + ) + if err != nil { + return err + } + req.Header.Add("Authorization", "Bearer "+g.token.AccessToken) + + response, err := restClient.Do(req) + if err != nil { + return err + } + defer response.Body.Close() + + if response.StatusCode != http.StatusNoContent { + // Graph API reference tells nothing about possible errors or the + // error response format. + return fmt.Errorf("failed to permanently delete the application with ID: %s", objectID) + } + + return nil +} diff --git a/internal/api/azure/dns.go b/internal/api/azure/dns.go new file mode 100644 index 0000000..fdbe568 --- /dev/null +++ b/internal/api/azure/dns.go @@ -0,0 +1,118 @@ +package azure + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +type dnsZoneParameters struct { + Tags map[string]string + SubscriptionID string + ResourceGroupName string + ZoneName string + Location string +} + +type dnsZone struct { + Id string + Name string + Type string + Location string +} + +// createDNSZone creates a new DNS zone in the specified resource group. +func (g *managementClient) createDNSZone(param dnsZoneParameters) (*dnsZone, error) { + err := g.refreshToken() + if err != nil { + return nil, err + } + + restEndpoint := fmt.Sprintf( + "https://management.azure.com/"+ + "subscriptions/%s/"+ + "resourceGroups/%s/"+ + "providers/Microsoft.Network/dnsZones/%s"+ + "?api-version=2018-05-01", + param.SubscriptionID, + param.ResourceGroupName, + param.ZoneName, + ) + + restClient := &http.Client{ + Timeout: 10 * time.Second, + } + + payloadBytes, err := json.Marshal( + struct { + Tags map[string]string `json:"tags,omitempty"` + Location string `json:"location"` + }{ + Location: param.Location, + Tags: param.Tags, + }, + ) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("PUT", restEndpoint, bytes.NewBuffer(payloadBytes)) + if err != nil { + return nil, err + } + req.Header.Add("Authorization", "Bearer "+g.token.AccessToken) + req.Header.Add("Content-Type", "application/json") + + response, err := restClient.Do(req) + if err != nil { + return nil, err + } + defer response.Body.Close() + + responseData, err := io.ReadAll(response.Body) + if err != nil { + return nil, err + } + + switch response.StatusCode { + case http.StatusOK, http.StatusCreated: + zoneInfo := struct { + Id string `json:"id"` + Name string `json:"name"` + Type string `json:"type"` + Location string `json:"location"` + }{} + err = json.Unmarshal(responseData, &zoneInfo) + if err != nil { + return nil, err + } + + return &dnsZone{ + Id: zoneInfo.Id, + Name: zoneInfo.Name, + Type: zoneInfo.Type, + Location: zoneInfo.Location, + }, nil + + default: + errorResponse := struct { + Error struct { + Code string `json:"code"` + Message string `json:"message"` + } `json:"error"` + }{} + err = json.Unmarshal(responseData, &errorResponse) + if err != nil { + return nil, err + } + + return nil, fmt.Errorf( + "error: %s, %s", + errorResponse.Error.Code, + errorResponse.Error.Message, + ) + } +} diff --git a/internal/api/azure/graph_client.go b/internal/api/azure/graph_client.go new file mode 100644 index 0000000..3dc5fbf --- /dev/null +++ b/internal/api/azure/graph_client.go @@ -0,0 +1,39 @@ +package azure + +import ( + "time" +) + +type graphClient struct { + oauth *oauth2Client + token *oauth2Token +} + +// initGraphClient initializes a new MS Graph API client with the given +// tenantId, clientId, and secret. +func initGraphClient(tenantId string, clientId string, secret string) *graphClient { + token := &oauth2Token{ + Expires: time.Unix(0, 0).UTC(), + } + + return &graphClient{ + token: token, + oauth: oauthInit(tenantId, clientId, secret), + } +} + +// refreshToken checks if the current access token is about (~ 5 minutes) to +// expire and requests a new token if necessary. +func (g *graphClient) refreshToken() error { + difference := time.Until(g.token.Expires) + if difference <= (5 * time.Minute) { + token, err := g.oauth.requestToken("https://graph.microsoft.com") + if err != nil { + return err + } + + g.token = token + } + + return nil +} diff --git a/internal/api/azure/management_client.go b/internal/api/azure/management_client.go new file mode 100644 index 0000000..cd62208 --- /dev/null +++ b/internal/api/azure/management_client.go @@ -0,0 +1,39 @@ +package azure + +import ( + "time" +) + +type managementClient struct { + oauth *oauth2Client + token *oauth2Token +} + +// initManagementClient initializes a new Azure API client (ResourceManagement scope)\ +// with the given tenantId, clientId, and secret. +func initManagementClient(tenantId string, clientId string, secret string) *managementClient { + token := &oauth2Token{ + Expires: time.Unix(0, 0).UTC(), + } + + return &managementClient{ + token: token, + oauth: oauthInit(tenantId, clientId, secret), + } +} + +// refreshToken checks if the current access token is about (~ 5 minutes) to +// expire and requests a new token if necessary. +func (g *managementClient) refreshToken() error { + difference := time.Until(g.token.Expires) + if difference <= (5 * time.Minute) { + token, err := g.oauth.requestToken("https://management.azure.com") + if err != nil { + return err + } + + g.token = token + } + + return nil +} diff --git a/internal/api/azure/oauth2.go b/internal/api/azure/oauth2.go new file mode 100644 index 0000000..9312b2f --- /dev/null +++ b/internal/api/azure/oauth2.go @@ -0,0 +1,97 @@ +package azure + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "time" +) + +type oauth2Token struct { + TokenType string + Expires time.Time + AccessToken string +} + +type oauth2Client struct { + tenantId string + clientId string + secret string +} + +// oauthInit initializes a new OAuth2 client with the given tenantId, clientId, +// and secret. +func oauthInit(tenantId string, clientId string, secret string) *oauth2Client { + return &oauth2Client{ + tenantId: tenantId, + clientId: clientId, + secret: secret, + } +} + +// requestToken retrieves OAuth2 token for the specified scope. +func (o *oauth2Client) requestToken(scope string) (*oauth2Token, error) { + restClient := &http.Client{ + Timeout: 10 * time.Second, + } + + v := url.Values{} + v.Add("client_id", o.clientId) + v.Add("scope", scope+"/.default") + v.Add("client_secret", o.secret) + v.Add("grant_type", "client_credentials") + + response, err := restClient.PostForm( + fmt.Sprintf( + "https://login.microsoftonline.com/%s/oauth2/v2.0/token", + o.tenantId), + v, + ) + if err != nil { + return nil, err + } + defer response.Body.Close() + + responseData, err := io.ReadAll(response.Body) + if err != nil { + return nil, err + } + + switch response.StatusCode { + case http.StatusOK: + oauthToken := struct { + TokenType string `json:"token_type"` + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + }{} + err = json.Unmarshal(responseData, &oauthToken) + if err != nil { + return nil, err + } + + return &oauth2Token{ + TokenType: oauthToken.TokenType, + Expires: time.Now().Add(time.Duration(oauthToken.ExpiresIn) * time.Second).UTC(), + AccessToken: oauthToken.AccessToken, + }, nil + + case http.StatusBadRequest, http.StatusUnauthorized: + error := struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + Timestamp string `json:"timestamp"` + TraceID string `json:"trace_id"` + CorrelationID string `json:"correlation_id"` + ErrorCodes []int `json:"error_codes"` + }{} + err = json.Unmarshal(responseData, &error) + if err != nil { + return nil, err + } + return nil, fmt.Errorf("can't get token: %s", error.ErrorDescription) + default: + return nil, fmt.Errorf("unexpected status code: %d", response.StatusCode) + } +} diff --git a/internal/api/azure/pool.go b/internal/api/azure/pool.go new file mode 100644 index 0000000..231dae3 --- /dev/null +++ b/internal/api/azure/pool.go @@ -0,0 +1,87 @@ +package azure + +import ( + "fmt" + "io" + "net/http" + "time" +) + +type poolClient struct { + projectTag string + poolID string + poolAPISecret string +} + +// initPoolClient initializes a new Subscription Pool management API client. +func InitPoolClient(projectTag string, poolId string, poolAPISecret string) *poolClient { + return &poolClient{ + projectTag: projectTag, + poolID: poolId, + poolAPISecret: poolAPISecret, + } +} + +// allocatePool requests a new Subscription from the pool. +func (pc *poolClient) AllocatePool() (string, error) { + restClient := &http.Client{ + Timeout: 10 * time.Second, + } + + req, err := http.NewRequest( + "GET", + fmt.Sprintf( + "https://rhpdspoolhandler.azurewebsites.net/api/get/%s/%s?code=%s", + pc.projectTag, + pc.poolID, + pc.poolAPISecret), + nil) + if err != nil { + return "", err + } + + response, err := restClient.Do(req) + if err != nil { + return "", err + } + defer response.Body.Close() + + responseData, err := io.ReadAll(response.Body) + if err != nil { + return "", err + } + + return string(responseData), nil +} + +// releasePool releases allocated Subscription back to pool. +func (pc *poolClient) ReleasePool() error { + restClient := &http.Client{ + Timeout: 10 * time.Second, + } + + req, err := http.NewRequest( + "GET", + fmt.Sprintf( + "https://rhpdspoolhandler.azurewebsites.net/api/release/%s/%s?code=%s", + pc.projectTag, + pc.poolID, + pc.poolAPISecret), + nil) + if err != nil { + return err + } + + response, err := restClient.Do(req) + if err != nil { + return err + } + defer response.Body.Close() + + _, err = io.ReadAll(response.Body) + if err != nil { + return err + } + + return nil +} diff --git a/internal/api/azure/resource_group.go b/internal/api/azure/resource_group.go new file mode 100644 index 0000000..3383fa4 --- /dev/null +++ b/internal/api/azure/resource_group.go @@ -0,0 +1,248 @@ +package azure + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +type resourceGroupParameters struct { + Tags map[string]string + SubscriptionId string + ResourceGroupName string + Location string +} + +type resourceGroup struct { + Id string + Name string + Location string + ProvisioningState string +} + +// createResourceGroup creates a new Resource Group. +func (g *managementClient) createResourceGroup(param resourceGroupParameters) (*resourceGroup, error) { + err := g.refreshToken() + if err != nil { + return nil, err + } + + payloadBytes, err := json.Marshal( + struct { + Tags map[string]string `json:"tags,omitempty"` + Location string `json:"location"` + }{ + Location: param.Location, + Tags: param.Tags, + }, + ) + if err != nil { + return nil, err + } + + restClient := &http.Client{ + Timeout: 10 * time.Second, + } + + req, err := http.NewRequest( + "PUT", + fmt.Sprintf( + "https://management.azure.com/subscriptions/%s/resourcegroups/%s?api-version=2021-04-01", + strings.Trim(param.SubscriptionId, "/"), + strings.Trim(param.ResourceGroupName, "/"), + ), + bytes.NewReader(payloadBytes)) + if err != nil { + return nil, err + } + req.Header.Add("Authorization", "Bearer "+g.token.AccessToken) + req.Header.Add("Content-Type", "application/json") + + response, err := restClient.Do(req) + if err != nil { + return nil, err + } + defer response.Body.Close() + + responseData, err := io.ReadAll(response.Body) + if err != nil { + return nil, err + } + + switch response.StatusCode { + case http.StatusOK, http.StatusCreated: + groupInfo := struct { + Id string `json:"id"` + Name string `json:"name"` + Location string `json:"location"` + Properties struct { + ProvisioningState string `json:"provisioningState"` + } `json:"properties"` + }{} + err = json.Unmarshal(responseData, &groupInfo) + if err != nil { + return nil, err + } + + return &resourceGroup{ + Id: groupInfo.Id, + Name: groupInfo.Name, + Location: groupInfo.Location, + ProvisioningState: groupInfo.Properties.ProvisioningState, + }, nil + default: + errorResponse := struct { + Error struct { + Code string `json:"code"` + Message string `json:"message"` + } `json:"error"` + }{} + err = json.Unmarshal(responseData, &errorResponse) + if err != nil { + return nil, err + } + + return nil, fmt.Errorf("error: %s, %s", + errorResponse.Error.Code, + errorResponse.Error.Message) + } +} + +// listResourceGroups returns a list of Resource Groups for the Subscription. +func (g *managementClient) listResourceGroups(subscriptionId string) ([]resourceGroup, error) { + err := g.refreshToken() + if err != nil { + return nil, err + } + + restClient := &http.Client{ + Timeout: 10 * time.Second, + } + + req, err := http.NewRequest( + "GET", + fmt.Sprintf( + "https://management.azure.com/subscriptions/%s/resourcegroups?api-version=2021-04-01", + subscriptionId), + nil) + if err != nil { + return nil, err + } + req.Header.Add("Authorization", "Bearer "+g.token.AccessToken) + + response, err := restClient.Do(req) + if err != nil { + return nil, err + } + defer response.Body.Close() + + responseData, err := io.ReadAll(response.Body) + if err != nil { + return nil, err + } + + switch response.StatusCode { + case http.StatusOK: + groupInfo := struct { + Value []struct { + Id string `json:"id"` + Name string `json:"name"` + Location string `json:"location"` + Properties struct { + ProvisioningState string `json:"provisioningState"` + } `json:"properties"` + } `json:"value"` + }{} + err = json.Unmarshal(responseData, &groupInfo) + if err != nil { + return nil, err + } + + var groups []resourceGroup + + for _, rg := range groupInfo.Value { + groups = append(groups, resourceGroup{ + Id: rg.Id, + Name: rg.Name, + Location: rg.Location, + ProvisioningState: rg.Properties.ProvisioningState, + }) + } + + return groups, nil + + default: + errorResponse := struct { + Error struct { + Code string `json:"code"` + Message string `json:"message"` + } `json:"error"` + }{} + err = json.Unmarshal(responseData, &errorResponse) + if err != nil { + return nil, err + } + + return nil, fmt.Errorf("error: %s, %s", + errorResponse.Error.Code, + errorResponse.Error.Message) + } +} + +// deleteResourceGroup deletes a resource group. +func (g *managementClient) deleteResourceGroup(resourceGroupID string) error { + err := g.refreshToken() + if err != nil { + return err + } + + restClient := &http.Client{ + Timeout: 10 * time.Second, + } + + req, err := http.NewRequest( + "DELETE", + fmt.Sprintf( + "https://management.azure.com/%s?api-version=2021-04-01", + resourceGroupID), + nil) + if err != nil { + return err + } + req.Header.Add("Authorization", "Bearer "+g.token.AccessToken) + + response, err := restClient.Do(req) + if err != nil { + return err + } + defer response.Body.Close() + + switch response.StatusCode { + case http.StatusOK, http.StatusAccepted: + return nil + default: + responseData, err := io.ReadAll(response.Body) + if err != nil { + return err + } + + errorResponse := struct { + Error struct { + Code string `json:"code"` + Message string `json:"message"` + } `json:"error"` + }{} + err = json.Unmarshal(responseData, &errorResponse) + if err != nil { + return err + } + + return fmt.Errorf("error: %s, %s", + errorResponse.Error.Code, + errorResponse.Error.Message) + } +} diff --git a/internal/api/azure/role.go b/internal/api/azure/role.go new file mode 100644 index 0000000..268ed05 --- /dev/null +++ b/internal/api/azure/role.go @@ -0,0 +1,353 @@ +package azure + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/google/uuid" +) + +type roleDefinition struct { + ID string + Type string + Name string +} + +type roleAssignment struct { + ID string + Type string + Name string +} + +// getRoleDefinition searches for the roleName at the specified scope. +func (g *managementClient) getRoleDefinition(scope string, roleName string) (*roleDefinition, error) { + err := g.refreshToken() + if err != nil { + return nil, err + } + + restClient := &http.Client{ + Timeout: 10 * time.Second, + } + + req, err := http.NewRequest( + "GET", + fmt.Sprintf( + "https://management.azure.com/%s/providers/Microsoft.Authorization/roleDefinitions?api-version=2022-04-01", + strings.Trim(scope, "")), + nil, + ) + if err != nil { + return nil, err + } + req.Header.Add("Authorization", "Bearer "+g.token.AccessToken) + + response, err := restClient.Do(req) + if err != nil { + return nil, err + } + defer response.Body.Close() + + responseData, err := io.ReadAll(response.Body) + if err != nil { + return nil, err + } + + switch response.StatusCode { + case http.StatusOK: + roleDefinitions := struct { + Value []struct { + Properties struct { + RoleName string `json:"roleName"` + Type string `json:"type"` + Description string `json:"description"` + } `json:"properties"` + ID string `json:"id"` + Type string `json:"type"` + Name string `json:"name"` + } `json:"value"` + }{} + err = json.Unmarshal(responseData, &roleDefinitions) + if err != nil { + return nil, err + } + + for _, role := range roleDefinitions.Value { + if role.Properties.RoleName == roleName { + return &roleDefinition{ + ID: role.ID, + Type: role.Type, + Name: role.Name, + }, nil + } + } + + return nil, fmt.Errorf("role definition for \"%s\" role not found", roleName) + default: + errorResponse := struct { + Error struct { + Code string `json:"code"` + Message string `json:"message"` + } `json:"error"` + }{} + + err = json.Unmarshal(responseData, &errorResponse) + if err != nil { + return nil, err + } + + return nil, fmt.Errorf("error: %s, %s", + errorResponse.Error.Code, + errorResponse.Error.Message) + } +} + +// createRoleAssignment creates a new role assignment at the specified scope. If +// the role assignment already exists, it returns details about the existing role +// assignment. +func (g *managementClient) createRoleAssignment( + scope string, + roleDefinitionId string, + principalId string, + principalType string, +) (*roleAssignment, error) { + err := g.refreshToken() + if err != nil { + return nil, err + } + + restClient := &http.Client{ + Timeout: 10 * time.Second, + } + + requestBody := struct { + Properties struct { + RoleDefinitionID string `json:"roleDefinitionId"` + PrincipalID string `json:"principalId"` + PrincipalType string `json:"principalType"` + } `json:"properties"` + }{} + requestBody.Properties.RoleDefinitionID = roleDefinitionId + requestBody.Properties.PrincipalID = principalId + requestBody.Properties.PrincipalType = principalType + payloadBytes, err := json.Marshal(requestBody) + if err != nil { + return nil, err + } + + req, err := http.NewRequest( + "PUT", + fmt.Sprintf( + "https://management.azure.com/%s/providers/Microsoft.Authorization/roleAssignments/%s?api-version=2022-04-01", + strings.Trim(scope, "/"), + uuid.New(), + ), + bytes.NewReader(payloadBytes), + ) + if err != nil { + return nil, err + } + req.Header.Add("Authorization", "Bearer "+g.token.AccessToken) + req.Header.Add("Content-Type", "application/json") + + response, err := restClient.Do(req) + if err != nil { + return nil, err + } + defer response.Body.Close() + + responseData, err := io.ReadAll(response.Body) + if err != nil { + return nil, err + } + + switch response.StatusCode { + case http.StatusOK, http.StatusCreated: + assignment := struct { + Properties struct { + RoleDefinitionID string `json:"roleDefinitionId"` + PrincipalID string `json:"principalId"` + PrincipalType string `json:"principalType"` + Scope string `json:"scope"` + } `json:"properties"` + Id string `json:"id"` + Type string `json:"type"` + Name string `json:"name"` + }{} + err = json.Unmarshal(responseData, &assignment) + if err != nil { + return nil, err + } + + return &roleAssignment{ + ID: assignment.Id, + Type: assignment.Type, + Name: assignment.Name, + }, nil + + default: + errorResponse := struct { + Error struct { + Code string `json:"code"` + Message string `json:"message"` + } `json:"error"` + }{} + err = json.Unmarshal(responseData, &errorResponse) + if err != nil { + return nil, err + } + + return nil, fmt.Errorf("error: %s, %s", + errorResponse.Error.Code, + errorResponse.Error.Message) + + } +} + +// getRoleAssignments search for the role assignments at the specified scope and +// role definition ID. +func (g *managementClient) getRoleAssignments( + scope string, + roleDefinitionId string, +) ([]roleAssignment, error) { + err := g.refreshToken() + if err != nil { + return nil, err + } + + restClient := &http.Client{ + Timeout: 10 * time.Second, + } + + req, err := http.NewRequest( + "GET", + fmt.Sprintf( + "https://management.azure.com/%s/providers/Microsoft.Authorization/roleAssignments?api-version=2022-04-01", + strings.Trim(scope, "")), + nil) + if err != nil { + return nil, err + } + req.Header.Add("Authorization", "Bearer "+g.token.AccessToken) + + response, err := restClient.Do(req) + if err != nil { + return nil, err + } + defer response.Body.Close() + + responseData, err := io.ReadAll(response.Body) + if err != nil { + return nil, err + } + + switch response.StatusCode { + case http.StatusOK: + roleAssignments := struct { + Value []struct { + Properties struct { + RoleDefinitionID string `json:"roleDefinitionId"` + PrincipalID string `json:"principalId"` + } `json:"properties"` + Id string `json:"id"` + Type string `json:"type"` + Name string `json:"name"` + } `json:"value"` + }{} + err = json.Unmarshal(responseData, &roleAssignments) + if err != nil { + return nil, err + } + + var assignments []roleAssignment + for _, assignment := range roleAssignments.Value { + if assignment.Properties.RoleDefinitionID == roleDefinitionId { + assignments = append(assignments, roleAssignment{ + ID: assignment.Id, + Type: assignment.Type, + Name: assignment.Name, + }) + } + } + + return assignments, nil + + default: + errorResponse := struct { + Error struct { + Code string `json:"code"` + Message string `json:"message"` + } `json:"error"` + }{} + err = json.Unmarshal(responseData, &errorResponse) + if err != nil { + return nil, err + } + + return nil, fmt.Errorf("error: %s, %s", + errorResponse.Error.Code, + errorResponse.Error.Message) + } +} + +// deleteRoleAssignment deletes the role assignment. +func (g *managementClient) deleteRoleAssignment( + roleAssignmentId string, +) error { + err := g.refreshToken() + if err != nil { + return err + } + + restClient := &http.Client{ + Timeout: 10 * time.Second, + } + + req, err := http.NewRequest( + "DELETE", + fmt.Sprintf( + "https://management.azure.com/%s?api-version=2022-04-01", + strings.Trim(roleAssignmentId, "")), + nil) + if err != nil { + return err + } + req.Header.Add("Authorization", "Bearer "+g.token.AccessToken) + + response, err := restClient.Do(req) + if err != nil { + return err + } + defer response.Body.Close() + + responseData, err := io.ReadAll(response.Body) + if err != nil { + return err + } + + switch response.StatusCode { + case http.StatusOK, http.StatusNoContent: + return nil + + default: + errorResponse := struct { + Error struct { + Code string `json:"code"` + Message string `json:"message"` + } `json:"error"` + }{} + err = json.Unmarshal(responseData, &errorResponse) + if err != nil { + return err + } + + return fmt.Errorf("error: %s, %s", + errorResponse.Error.Code, + errorResponse.Error.Message) + } +} diff --git a/internal/api/azure/sandbox.go b/internal/api/azure/sandbox.go new file mode 100644 index 0000000..db5c333 --- /dev/null +++ b/internal/api/azure/sandbox.go @@ -0,0 +1,301 @@ +package azure + +const ( + sandboxRoleName = "Custom-Owner (Block Billing and Subscription deletion)" + rgNamePrefix = "openenv-" + rgDefaultLocation = "eastus" + dnsDefaultLocation = "Global" + defaultAppPrefix = "api://openenv-" +) + +type AzureCredentials struct { + TenantID string + ClientID string + Secret string +} + +type SandboxClient struct { + graphClient *graphClient + managementClient *managementClient +} + +type SandboxInfo struct { + SubscriptionName string + SubscriptionId string + ResourceGroupName string + AppID string + DisplayName string + Password string +} + +func InitSandboxClient( /*pool Pool, */ credentials AzureCredentials) *SandboxClient { + gc := initGraphClient( + credentials.TenantID, + credentials.ClientID, + credentials.Secret, + ) + + mc := initManagementClient( + credentials.TenantID, + credentials.ClientID, + credentials.Secret, + ) + + return &SandboxClient{ + graphClient: gc, + managementClient: mc, + } +} + +func (sc *SandboxClient) CreateSandboxEnvironment( + subscriptionName string, + requestorEmail string, + guid string, + costCenter string, + zoneDomain string, +) (*SandboxInfo, error) { + adUser, err := sc.graphClient.getUser(requestorEmail) + if err != nil { + return nil, err + } + + subscription, err := sc.managementClient.getSubscription(subscriptionName) + if err != nil { + return nil, err + } + + err = sc.setSandboxTags(guid, requestorEmail, costCenter, subscription.SubscriptionFQID) + if err != nil { + return nil, err + } + + err = sc.createRoleAssignment(subscription.SubscriptionFQID, adUser.Id, "User") + if err != nil { + return nil, err + } + + rgName, err := sc.createResourceGroup(subscription.SubscriptionId, guid) + if err != nil { + return nil, err + } + + err = sc.createDNSZone(subscription.SubscriptionId, guid, rgName, zoneDomain) + if err != nil { + return nil, err + } + + appDetails, err := sc.registerApplication( + subscription.SubscriptionFQID, + defaultAppPrefix+guid) + if err != nil { + return nil, err + } + + return &SandboxInfo{ + SubscriptionName: subscriptionName, + SubscriptionId: subscription.SubscriptionId, + ResourceGroupName: rgName, + AppID: appDetails.AppID, + DisplayName: appDetails.DisplayName, + Password: appDetails.Password, + }, nil +} + +func (sc *SandboxClient) CleanupSandboxEnvironment(subscriptionName string, guid string) error { + subscription, err := sc.managementClient.getSubscription(subscriptionName) + if err != nil { + return err + } + + err = sc.deleteResourceGroups(subscription.SubscriptionId) + if err != nil { + return err + } + + err = sc.deleteApplications(defaultAppPrefix + guid) + if err != nil { + return err + } + + err = sc.deleteRoleAssignments(subscription.SubscriptionFQID) + if err != nil { + return err + } + + err = sc.deleteSandboxTags(subscription.SubscriptionFQID) + if err != nil { + return err + } + + return nil +} + +func (sc *SandboxClient) setSandboxTags( + guid string, + requestorEmail string, + costCenter string, + scope string, +) error { + tags := make(map[string]string) + tags["GUID"] = guid + tags["EMAIL"] = requestorEmail + tags["cost-center"] = costCenter + + err := sc.managementClient.setTags(scope, tags) + if err != nil { + return err + } + + return nil +} + +func (sc *SandboxClient) deleteSandboxTags(scope string) error { + tags := make(map[string]string) + tags["GUID"] = "" + tags["EMAIL"] = "" + err := sc.managementClient.updateTags(scope, tags, "delete") + if err != nil { + return err + } + + return nil +} + +func (sc *SandboxClient) createRoleAssignment(scope string, principalID string, principalType string) error { + roleDefinition, err := sc.managementClient.getRoleDefinition( + scope, + sandboxRoleName) + if err != nil { + return err + } + + _, err = sc.managementClient.createRoleAssignment( + scope, + roleDefinition.ID, + principalID, + principalType, + ) + if err != nil { + return err + } + + return nil +} + +func (sc *SandboxClient) deleteRoleAssignments(scope string) error { + roleDefinition, err := sc.managementClient.getRoleDefinition( + scope, + sandboxRoleName) + if err != nil { + return err + } + + roleAssignments, err := sc.managementClient.getRoleAssignments( + scope, + roleDefinition.ID) + if err != nil { + return err + } + + for _, assignment := range roleAssignments { + err = sc.managementClient.deleteRoleAssignment(assignment.ID) + if err != nil { + return err + } + } + + return nil +} + +func (sc *SandboxClient) createResourceGroup(subscriptionId string, guid string) (string, error) { + rgTags := make(map[string]string) + rgTags["GUID"] = guid + rgParams := resourceGroupParameters{ + SubscriptionId: subscriptionId, + ResourceGroupName: rgNamePrefix + guid, + Location: rgDefaultLocation, + Tags: rgTags, + } + + rg, err := sc.managementClient.createResourceGroup(rgParams) + if err != nil { + return "", err + } + + return rg.Name, nil +} + +func (sc *SandboxClient) deleteResourceGroups(subscriptionId string) error { + rgs, err := sc.managementClient.listResourceGroups(subscriptionId) + if err != nil { + return err + } + + for _, rg := range rgs { + err = sc.managementClient.deleteResourceGroup(rg.Id) + if err != nil { + return err + } + } + + return nil +} + +func (sc *SandboxClient) createDNSZone(subscriptionId string, guid string, rgName string, zoneDomain string) error { + dnsTags := make(map[string]string) + dnsTags["GUID"] = guid + dnsZoneParams := dnsZoneParameters{ + SubscriptionID: subscriptionId, + ResourceGroupName: rgName, + ZoneName: guid + "." + zoneDomain, + Location: dnsDefaultLocation, + Tags: dnsTags, + } + + _, err := sc.managementClient.createDNSZone(dnsZoneParams) + if err != nil { + return err + } + + return nil +} + +func (sc *SandboxClient) registerApplication(scope string, name string) (*application, error) { + app, err := sc.graphClient.createApplication(name) + if err != nil { + return nil, err + } + + sp, err := sc.graphClient.createServicePrincipal(app.AppID) + if err != nil { + return nil, err + } + + err = sc.createRoleAssignment(scope, sp.id, "ServicePrincipal") + if err != nil { + return nil, err + } + + return app, nil +} + +func (sc *SandboxClient) deleteApplications(name string) error { + appIds, err := sc.graphClient.getApplicationObjectIDs(name) + if err != nil { + return err + } + + for _, id := range appIds { + err = sc.graphClient.deleteApplication(id) + if err != nil { + return err + } + + err = sc.graphClient.permanentDeleteApplication(id) + if err != nil { + return err + } + } + + return nil +} diff --git a/internal/api/azure/service_principalp.go b/internal/api/azure/service_principalp.go new file mode 100644 index 0000000..e20d235 --- /dev/null +++ b/internal/api/azure/service_principalp.go @@ -0,0 +1,75 @@ +package azure + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +type servicePrincipal struct { + id string +} + +func (g *graphClient) createServicePrincipal(appID string) (*servicePrincipal, error) { + err := g.refreshToken() + if err != nil { + return nil, err + } + + restClient := &http.Client{ + Timeout: 10 * time.Second, + } + + requestBody := struct { + AppID string `json:"appId"` + }{ + AppID: appID, + } + + payloadBytes, err := json.Marshal(requestBody) + if err != nil { + return nil, err + } + + req, err := http.NewRequest( + "POST", + "https://graph.microsoft.com/v1.0/servicePrincipals", + bytes.NewBuffer(payloadBytes), + ) + if err != nil { + return nil, err + } + + req.Header.Add("Authorization", "Bearer "+g.token.AccessToken) + req.Header.Add("Content-type", "application/json") + response, err := restClient.Do(req) + if err != nil { + return nil, err + } + defer response.Body.Close() + + responseData, err := io.ReadAll(response.Body) + if err != nil { + return nil, err + } + + if response.StatusCode != http.StatusCreated { + // Graph API reference has no information about error codes + // returned by this endpoint. + return nil, fmt.Errorf("failed to create service principal: %s", response.Status) + } + + responseBody := struct { + ID string `json:"id"` + }{} + err = json.Unmarshal(responseData, &responseBody) + if err != nil { + return nil, err + } + return &servicePrincipal{ + id: responseBody.ID, + }, nil +} diff --git a/internal/api/azure/subscription.go b/internal/api/azure/subscription.go new file mode 100644 index 0000000..c66dab7 --- /dev/null +++ b/internal/api/azure/subscription.go @@ -0,0 +1,74 @@ +package azure + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +type subscription struct { + SubscriptionId string + SubscriptionFQID string + DisplayName string +} + +// Retrieves the subscription details for the given subscription name. +// It uses the Microsoft OAuth2 client to request the subscription details. Returns +// the Subscription details or an error if the subscription was not found. +func (g *managementClient) getSubscription(name string) (*subscription, error) { + err := g.refreshToken() + if err != nil { + return nil, err + } + + restClient := &http.Client{ + Timeout: 10 * time.Second, + } + + req, err := http.NewRequest( + "GET", + "https://management.azure.com/subscriptions?api-version=2022-12-01", + nil) + if err != nil { + return nil, err + } + req.Header.Add("Authorization", "Bearer "+g.token.AccessToken) + + response, err := restClient.Do(req) + if err != nil { + return nil, err + } + defer response.Body.Close() + + responseData, err := io.ReadAll(response.Body) + if err != nil { + return nil, err + } + + // Base on the Azure REST API reference, the only response code possible is 200 + subscriptions := struct { + Value []struct { + Id string `json:"id"` + SubscriptionId string `json:"subscriptionId"` + DisplayName string `json:"displayName"` + } `json:"value"` + }{} + err = json.Unmarshal(responseData, &subscriptions) + if err != nil { + return nil, err + } + + for _, sub := range subscriptions.Value { + if sub.DisplayName == name { + return &subscription{ + SubscriptionId: sub.SubscriptionId, + SubscriptionFQID: sub.Id, + DisplayName: sub.DisplayName, + }, nil + } + } + + return nil, fmt.Errorf("subscription %s not found", name) +} diff --git a/internal/api/azure/tag.go b/internal/api/azure/tag.go new file mode 100644 index 0000000..4908020 --- /dev/null +++ b/internal/api/azure/tag.go @@ -0,0 +1,143 @@ +package azure + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +// setTags sets the tags for the specified scope. +func (g *managementClient) setTags(scope string, tags map[string]string) error { + err := g.refreshToken() + if err != nil { + return err + } + + restClient := &http.Client{ + Timeout: 10 * time.Second, + } + + requestBody := struct { + Properties struct { + Tags map[string]string `json:"tags"` + } `json:"properties"` + }{} + requestBody.Properties.Tags = tags + payloadBytes, err := json.Marshal(requestBody) + if err != nil { + return err + } + + req, err := http.NewRequest( + "PUT", + fmt.Sprintf( + "https://management.azure.com/%s/providers/Microsoft.Resources/tags/default?api-version=2021-04-01", + strings.Trim(scope, "/")), + bytes.NewReader(payloadBytes)) + if err != nil { + return err + } + req.Header.Add("Authorization", "Bearer "+g.token.AccessToken) + req.Header.Add("Content-Type", "application/json") + + response, err := restClient.Do(req) + if err != nil { + return err + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + responseData, err := io.ReadAll(response.Body) + if err != nil { + return err + } + + errorResponse := struct { + Error struct { + Code string `json:"code"` + Message string `json:"message"` + } `json:"error"` + }{} + err = json.Unmarshal(responseData, &errorResponse) + if err != nil { + return err + } + + return fmt.Errorf("error: %s, %s", + errorResponse.Error.Code, + errorResponse.Error.Message) + } + + return nil +} + +// updateTags updates (or delete) the tags for the specified scope. +func (g *managementClient) updateTags(scope string, tags map[string]string, operation string) error { + err := g.refreshToken() + if err != nil { + return err + } + + restClient := &http.Client{ + Timeout: 10 * time.Second, + } + + requestBody := struct { + Properties struct { + Tags map[string]string `json:"tags"` + } `json:"properties"` + Operation string `json:"operation"` + }{} + requestBody.Operation = operation + requestBody.Properties.Tags = tags + payloadBytes, err := json.Marshal(requestBody) + if err != nil { + return err + } + + req, err := http.NewRequest( + "PATCH", + fmt.Sprintf( + "https://management.azure.com/%s/providers/Microsoft.Resources/tags/default?api-version=2021-04-01", + strings.Trim(scope, "/")), + bytes.NewReader(payloadBytes)) + if err != nil { + return err + } + req.Header.Add("Authorization", "Bearer "+g.token.AccessToken) + req.Header.Add("Content-Type", "application/json") + + response, err := restClient.Do(req) + if err != nil { + return err + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + responseData, err := io.ReadAll(response.Body) + if err != nil { + return err + } + + errorResponse := struct { + Error struct { + Code string `json:"code"` + Message string `json:"message"` + } `json:"error"` + }{} + err = json.Unmarshal(responseData, &errorResponse) + if err != nil { + return err + } + + return fmt.Errorf("error: %s, %s", + errorResponse.Error.Code, + errorResponse.Error.Message) + } + + return nil +} diff --git a/internal/api/azure/user.go b/internal/api/azure/user.go new file mode 100644 index 0000000..6de9bf4 --- /dev/null +++ b/internal/api/azure/user.go @@ -0,0 +1,92 @@ +package azure + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +type user struct { + DisplayName string + UserPrincipalName string + Id string +} + +// getUser retrieves user information. +func (g *graphClient) getUser(spName string) (*user, error) { + err := g.refreshToken() + if err != nil { + return nil, err + } + + restClient := &http.Client{ + Timeout: 10 * time.Second, + } + + req, err := http.NewRequest( + "GET", + fmt.Sprintf( + "https://graph.microsoft.com/v1.0/users('%s')?$select=displayName,userPrincipalName,id", + spName), + nil) + if err != nil { + return nil, err + } + req.Header.Add("Authorization", "Bearer "+g.token.AccessToken) + + response, err := restClient.Do(req) + if err != nil { + return nil, err + } + defer response.Body.Close() + + responseData, err := io.ReadAll(response.Body) + if err != nil { + return nil, err + } + + switch response.StatusCode { + case http.StatusOK: + userDetails := struct { + DisplayName string `json:"displayName"` + UserPrincipalName string `json:"userPrincipalName"` + Id string `json:"id"` + }{} + err = json.Unmarshal(responseData, &userDetails) + if err != nil { + return nil, err + } + + return &user{ + DisplayName: userDetails.DisplayName, + UserPrincipalName: userDetails.UserPrincipalName, + Id: userDetails.Id, + }, nil + + case http.StatusAccepted: + // It's not clear what to do in this case. Graph API documentation + // does not provide much information about this status code. So just + // return nil and an error. + return nil, fmt.Errorf( + "request was accepted by the Azure Graph API but no data"+ + "was returned for ServicePrincipal %s", spName) + + default: + errorResponse := struct { + Error struct { + Code string `json:"code"` + Message string `json:"message"` + } `json:"error"` + }{} + err = json.Unmarshal(responseData, &errorResponse) + if err != nil { + panic(err) + } + + return nil, fmt.Errorf("error: %s, %s", + errorResponse.Error.Code, + errorResponse.Error.Message) + } +} diff --git a/internal/models/azure_sandbox.go b/internal/models/azure_sandbox.go new file mode 100644 index 0000000..ef84baa --- /dev/null +++ b/internal/models/azure_sandbox.go @@ -0,0 +1,510 @@ +package models + +import ( + "context" + "encoding/json" + "fmt" + "math/rand" + "os" + "sync" + "time" + + "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgxpool" + "github.com/rhpds/sandbox/internal/api/azure" + "github.com/rhpds/sandbox/internal/log" +) + +const ( + subscriptionNamePrefix = "pool-01-" + subscriptionCount = 10 + + // Sand box can be in state when deletion is not possible + // (e.g initializating). Those two constants controls + // how long delay process will last until error occurs + // up to deleteMaxRetries * deleteRetryDelay seconds + deleteMaxRetries = 10 + deleteRetryDelay = 5 +) + +type AzureSandboxProvider struct { + dbPool *pgxpool.Pool + vaultSecret string + + azureTenantId string + azureClientId string + azureSecret string + azurePoolApiSecret string + + poolMutex sync.Mutex +} + +type AzureSandboxWithCreds struct { + AzureSandbox + + Credentials []any `json:"credentials"` + Provider *AzureSandboxProvider `json:"-"` +} + +type AzureSandbox struct { + Id int `json:"id,omitempty"` + Name string `json:"name"` + Kind string `json:"kind"` // AzureSandbox + ServiceUuid string `json:"service_uuid"` + Status string `json:"status"` + CleanupCount int `json:"cleanup_count"` + Annotations Annotations `json:"annotations"` + ToCleanup bool `json:"to_cleanup"` + SubscriptionName string `json:"subscription_name"` + SubscriptionId string `json:"subscription_id"` + ResourceGroupName string `json:"resource_group_name"` + AppID string `json:"app_id"` + DisplayName string `json:"display_name"` +} + +func NewAzureSandboxProvider( + dbPool *pgxpool.Pool, + vaultSecret string, +) (*AzureSandboxProvider, error) { + provider := &AzureSandboxProvider{ + dbPool: dbPool, + vaultSecret: vaultSecret, + } + + if provider.azureTenantId = os.Getenv("AZURE_TENANT_ID"); provider.azureTenantId == "" { + return nil, fmt.Errorf("AZURE_TENANT_ID is not set") + } + + if provider.azureClientId = os.Getenv("AZURE_CLIENT_ID"); provider.azureClientId == "" { + return nil, fmt.Errorf("AZURE_CLIENT_ID is not set") + } + + if provider.azureSecret = os.Getenv("AZURE_SECRET"); provider.azureSecret == "" { + return nil, fmt.Errorf("AZURE_SECRET is not set") + } + + if provider.azurePoolApiSecret = os.Getenv("AZURE_POOL_API_SECRET"); provider.azurePoolApiSecret == "" { + return nil, fmt.Errorf("AZURE_POOL_API_SECRET is not set") + } + + return provider, nil +} + +func (a *AzureSandboxProvider) allocateSubscription() (string, error) { + SubscriptionNames := map[string]bool{} + + // Subscription names are not defined but used to get Subscription ID + // using Azure API calls. For simplicity we are using the subscriptionCount + // subscriptions starting from pool-01-001. + for i := 1; i <= subscriptionCount; i++ { + SubscriptionNames[fmt.Sprintf("%s%d", subscriptionNamePrefix, i)] = false + } + + rows, err := a.dbPool.Query( + context.Background(), + `SELECT resource_data ->> 'subscription_name' FROM resources WHERE status = 'success'`, + ) + if err != nil { + return "", fmt.Errorf("can't get retrieve about allocated pools") + } + defer rows.Close() + + allocatedSubscriptions := []string{} + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return "", fmt.Errorf("illegal pool name retrieved: %w", err) + } + allocatedSubscriptions = append(allocatedSubscriptions, name) + } + if err = rows.Err(); err != nil { + return "", fmt.Errorf("can't get allocated pool names: %w", err) + } + + for _, name := range allocatedSubscriptions { + if _, exists := SubscriptionNames[name]; exists { + SubscriptionNames[name] = true + } else { + log.Logger.Warn("Incorrect pool name found", "warning", name) + continue + } + } + + availableSubscriptions := make([]string, 0, len(SubscriptionNames)) + for k, v := range SubscriptionNames { + if !v { + availableSubscriptions = append(availableSubscriptions, k) + } + } + + if len(availableSubscriptions) == 0 { + return "", fmt.Errorf("no available pools") + } + + return availableSubscriptions[rand.Intn(len(availableSubscriptions))], nil +} + +func (a *AzureSandboxProvider) getNewSandboxName(guid string, serviceUuid string) (string, error) { + if guid == "" || serviceUuid == "" { + return "", fmt.Errorf("guid or serviceUuid is invalid") + } + + return fmt.Sprintf("%s-1-%s", guid, serviceUuid), nil +} + +func (a *AzureSandboxProvider) initNewAzureSandbox(serviceUuid string, annotations Annotations) (*AzureSandboxWithCreds, error) { + // Multiple Azure sandboxes can be initialize concurently + // and we should be sure that we are getting correct values + // for the new AzureSandboxWithCreds structure + a.poolMutex.Lock() + defer a.poolMutex.Unlock() + + azureSandbox := AzureSandboxWithCreds{ + AzureSandbox: AzureSandbox{ + Name: "noname", + Kind: "AzureSandbox", + ServiceUuid: serviceUuid, + Annotations: annotations, + Status: "initializing", + }, + Provider: a, + } + + sandboxName, err := a.getNewSandboxName( + annotations["guid"], + serviceUuid, + ) + if err != nil { + return nil, err + } + azureSandbox.AzureSandbox.Name = sandboxName + + subscriptionName, err := a.allocateSubscription() + if err != nil { + return nil, err + } + + azureSandbox.SubscriptionName = subscriptionName + + err = azureSandbox.Save() + if err != nil { + return nil, err + } + + return &azureSandbox, nil +} + +func (a *AzureSandboxProvider) Request( + serviceUuid string, + annotations Annotations, +) (AzureSandboxWithCreds, error) { + azureSandbox, err := a.initNewAzureSandbox(serviceUuid, annotations) + if err != nil { + log.Logger.Error("Can't init new Azure sandbox", "error", err) + return AzureSandboxWithCreds{}, err + } + + // Create the sandbox asynchronously + go azureSandbox.Create() + + return *azureSandbox, nil +} + +func (a *AzureSandboxProvider) FetchAllByServiceUuidWithCreds(serviceUuid string) ([]AzureSandboxWithCreds, error) { + sandboxes := []AzureSandboxWithCreds{} + // Get resource from above 'resources' table + rows, err := a.dbPool.Query( + context.Background(), + `SELECT + resource_data, + id, + resource_name, + resource_type, + status, + cleanup_count, + pgp_sym_decrypt(resource_credentials, $2) + FROM + resources + WHERE service_uuid = $1 AND resource_type = 'AzureSandbox'`, + serviceUuid, a.vaultSecret, + ) + if err != nil { + fmt.Printf("\n\nSQL error: %s\n\n", err) + if err == pgx.ErrNoRows { + log.Logger.Info("No account found", "service_uuid", serviceUuid) + } + return sandboxes, err + } + + for rows.Next() { + var sandbox AzureSandboxWithCreds + + creds := "" + if err := rows.Scan( + &sandbox, + &sandbox.Id, + &sandbox.Name, + &sandbox.Kind, + &sandbox.Status, + &sandbox.CleanupCount, + &creds, + ); err != nil { + return sandboxes, err + } + + // Unmarshal creds into account.Credentials + if err := json.Unmarshal([]byte(creds), &sandbox.Credentials); err != nil { + return sandboxes, err + } + + sandbox.ServiceUuid = serviceUuid + + sandboxes = append(sandboxes, sandbox) + } + + return sandboxes, nil +} + +func (a *AzureSandboxProvider) Release(serviceUuid string) error { + sandboxes, err := a.FetchAllByServiceUuidWithCreds(serviceUuid) + if err != nil { + return err + } + var errorHappened error + + for _, sandbox := range sandboxes { + sandbox.Provider = a + if err := sandbox.Delete(); err != nil { + errorHappened = err + continue + } + } + return errorHappened +} + +func (sb *AzureSandboxWithCreds) Update() error { + if sb.Id == 0 { + return fmt.Errorf("failed to update resources, Id is not set") + } + + credentials, err := json.Marshal(sb.Credentials) + if err != nil { + return fmt.Errorf("failed to marshal credentials: %w", err) + } + + sb.Credentials = []any{} + + _, err = sb.Provider.dbPool.Exec( + context.Background(), + `UPDATE resources + SET resource_name = $1, + resource_type = $2, + service_uuid = $3, + resource_data = $4, + resource_credentials = pgp_sym_encrypt($5::text, $6), + status = $7, + cleanup_count = $8 + WHERE id = $9`, + sb.Name, + sb.Kind, + sb.ServiceUuid, + sb, + credentials, + sb.Provider.vaultSecret, + sb.Status, + sb.CleanupCount, + sb.Id, + ) + if err != nil { + return fmt.Errorf("failed to update resource: %w", err) + } + + return nil +} + +func (sb *AzureSandboxWithCreds) Save() error { + if sb.Id != 0 { + return sb.Update() + } + + credentials, err := json.Marshal(sb.Credentials) + if err != nil { + return fmt.Errorf("failed to marshal credentials: %w", err) + } + + sb.Credentials = []any{} + + err = sb.Provider.dbPool.QueryRow( + context.Background(), + `INSERT INTO resources + (resource_name, resource_type, service_uuid, to_cleanup, resource_data, resource_credentials, status, cleanup_count) + VALUES ($1, $2, $3, $4, $5, pgp_sym_encrypt($6::text, $7), $8, $9) RETURNING id`, + sb.Name, + sb.Kind, + sb.ServiceUuid, + sb.ToCleanup, + sb, + credentials, + sb.Provider.vaultSecret, + sb.Status, + sb.CleanupCount, + ).Scan(&sb.Id) + if err != nil { + return fmt.Errorf("failed to insert resource: %w", err) + } + + return nil +} + +func (sb *AzureSandboxWithCreds) setStatus(status string) error { + _, err := sb.Provider.dbPool.Exec( + context.Background(), + fmt.Sprintf(`UPDATE resources + SET status = $1, + resource_data['status'] = to_jsonb('%s'::text) + WHERE id = $2`, status), + status, sb.Id, + ) + + return err +} + +func (sb *AzureSandboxWithCreds) getStatus() (string, error) { + var status string + err := sb.Provider.dbPool.QueryRow( + context.Background(), + "SELECT status FROM resources WHERE id = $1", + sb.Id, + ).Scan(&status) + + return status, err +} + +func (sb *AzureSandboxWithCreds) markForCleanup() error { + _, err := sb.Provider.dbPool.Exec( + context.Background(), + "UPDATE resources SET to_cleanup = true, resource_data['to_cleanup'] = 'true' where id = $1", + sb.Id, + ) + + return err +} + +func (sb *AzureSandboxWithCreds) Create() { + sandboxInfo, err := sb.requestAzureSandbox() + if err == nil { + sb.SubscriptionId = sandboxInfo.SubscriptionId + sb.ResourceGroupName = sandboxInfo.ResourceGroupName + sb.AppID = sandboxInfo.AppID + sb.DisplayName = sandboxInfo.DisplayName + sb.Credentials = []any{ + map[string]string{ + "password": sandboxInfo.Password, + }, + } + + sb.Status = "success" + } else { + log.Logger.Error("can't create Azure sandbox", "error", err, "name", sb.Name) + sb.Status = "error" + } + + err = sb.Save() + if err != nil { + log.Logger.Error("can't update Azure Sandbox status", "error", err) + return + } +} + +// models.Deletable interface implementation +func (sb *AzureSandboxWithCreds) Delete() error { + retryCount := deleteMaxRetries + for { + if retryCount == 0 { + err := fmt.Errorf("timeout error") + log.Logger.Error("can't delete resource", "error", err) + return err + } + + sandboxStatus, err := sb.getStatus() + if err != nil { + log.Logger.Error("can't get status of resource", "error", err, "name", sb.Name) + return err + } + + if sandboxStatus == "deleting" { + return nil + } + + if sandboxStatus == "success" || sandboxStatus == "error" { + break + } + + time.Sleep(deleteRetryDelay * time.Second) + retryCount-- + } + + sb.setStatus("deleting") + sb.markForCleanup() + + err := sb.cleanupAzureSandbox() + if err != nil { + log.Logger.Error("can't delete Azure resources", "error", err, "name", sb.Name) + sb.setStatus("error") + return err + } + + _, err = sb.Provider.dbPool.Exec( + context.Background(), + `DELETE FROM resources WHERE id = $1`, + sb.Id, + ) + if err != nil { + return fmt.Errorf("failed to remove resource: %w", err) + } + + return nil +} + +func (sb *AzureSandboxWithCreds) requestAzureSandbox() (*azure.SandboxInfo, error) { + sandboxClient := azure.InitSandboxClient( + azure.AzureCredentials{ + TenantID: sb.Provider.azureTenantId, + ClientID: sb.Provider.azureClientId, + Secret: sb.Provider.azureSecret, + }, + ) + + sandboxInfo, err := sandboxClient.CreateSandboxEnvironment( + sb.SubscriptionName, + sb.Annotations["requester"], + sb.Annotations["guid"], + sb.Annotations["cost_center"], + sb.Annotations["domain"], + ) + if err != nil { + return nil, err + } + + return sandboxInfo, nil +} + +func (sb *AzureSandboxWithCreds) cleanupAzureSandbox() error { + sandboxClient := azure.InitSandboxClient( + azure.AzureCredentials{ + TenantID: sb.Provider.azureTenantId, + ClientID: sb.Provider.azureClientId, + Secret: sb.Provider.azureSecret, + }, + ) + + err := sandboxClient.CleanupSandboxEnvironment( + sb.SubscriptionName, + sb.Annotations["guid"], + ) + if err != nil { + return err + } + + return nil +} diff --git a/internal/models/ocp_sandbox.go b/internal/models/ocp_sandbox.go index 24805cb..c20a27e 100644 --- a/internal/models/ocp_sandbox.go +++ b/internal/models/ocp_sandbox.go @@ -468,7 +468,6 @@ func (p *OcpSandboxProvider) GetOcpSharedClusterConfigurations() (OcpSharedClust FROM ocp_shared_cluster_configurations`, p.VaultSecret, ) - if err != nil { if err == pgx.ErrNoRows { log.Logger.Info("No cluster found") @@ -519,7 +518,6 @@ func (p *OcpSandboxProvider) GetOcpSharedClusterConfigurationByAnnotations(annot `SELECT name FROM ocp_shared_cluster_configurations WHERE annotations @> $1`, annotations, ) - if err != nil { if err == pgx.ErrNoRows { log.Logger.Info("No cluster found", "annotations", annotations) @@ -570,7 +568,6 @@ func (a *OcpSandbox) Save(dbpool *pgxpool.Pool) error { } func (a *OcpSandboxWithCreds) Update() error { - if a.ID == 0 { return errors.New("id must be > 0") } @@ -672,6 +669,7 @@ func (a *OcpSandboxWithCreds) IncrementCleanupCount() error { return err } + func (a *OcpSandboxProvider) FetchAllByServiceUuid(serviceUuid string) ([]OcpSandbox, error) { accounts := []OcpSandbox{} // Get resource from above 'resources' table @@ -694,7 +692,6 @@ func (a *OcpSandboxProvider) FetchAllByServiceUuid(serviceUuid string) ([]OcpSan WHERE r.service_uuid = $1`, serviceUuid, ) - if err != nil { if err == pgx.ErrNoRows { log.Logger.Info("No account found", "service_uuid", serviceUuid) @@ -745,10 +742,9 @@ func (a *OcpSandboxProvider) FetchAllByServiceUuidWithCreds(serviceUuid string) resources r LEFT JOIN ocp_shared_cluster_configurations oc ON oc.name = r.resource_data->>'ocp_cluster' - WHERE r.service_uuid = $1`, + WHERE r.service_uuid = $1 AND r.resource_type = 'OcpSandbox'`, serviceUuid, a.VaultSecret, ) - if err != nil { if err == pgx.ErrNoRows { log.Logger.Info("No account found", "service_uuid", serviceUuid) @@ -798,7 +794,6 @@ func (a *OcpSandboxProvider) GetSchedulableClusters(cloud_selector map[string]st `SELECT name FROM ocp_shared_cluster_configurations WHERE annotations @> $1 and valid=true ORDER BY random()`, cloud_selector, ) - if err != nil { if err == pgx.ErrNoRows { log.Logger.Info("No cluster found", "cloud_selector", cloud_selector) @@ -1023,7 +1018,6 @@ func (a *OcpSandboxProvider) Request(serviceUuid string, cloud_selector map[stri nodeMetric, err := clientsetMetrics.MetricsV1beta1(). NodeMetricses(). Get(context.Background(), node.Name, metav1.GetOptions{}) - if err != nil { log.Logger.Error( "Error Get OCP node metrics v1beta1, ignore the node", @@ -1115,7 +1109,6 @@ func (a *OcpSandboxProvider) Request(serviceUuid string, cloud_selector map[stri }, }, }, metav1.CreateOptions{}) - if err != nil { if strings.Contains(err.Error(), "object is being deleted: namespace") { log.Logger.Warn("Error creating OCP namespace", "error", err) @@ -1235,7 +1228,6 @@ func (a *OcpSandboxProvider) Request(serviceUuid string, cloud_selector map[stri }, }, }, metav1.CreateOptions{}) - if err != nil { log.Logger.Error("Error creating OCP service account", "error", err) // Delete the namespace @@ -1268,7 +1260,6 @@ func (a *OcpSandboxProvider) Request(serviceUuid string, cloud_selector map[stri }, }, }, metav1.CreateOptions{}) - if err != nil { log.Logger.Error("Error creating OCP RoleBind", "error", err) if err := clientset.CoreV1().Namespaces().Delete(context.TODO(), namespaceName, metav1.DeleteOptions{}); err != nil { @@ -1505,7 +1496,6 @@ func guessNextGuid(origGuid string, serviceUuid string, dbpool *pgxpool.Pool, mu AND resource_type = 'OcpSandbox'`, candidateName, ).Scan(&rowcount) - if err != nil { return "", err } @@ -1521,7 +1511,6 @@ func guessNextGuid(origGuid string, serviceUuid string, dbpool *pgxpool.Pool, mu func (a *OcpSandboxProvider) Release(service_uuid string) error { accounts, err := a.FetchAllByServiceUuidWithCreds(service_uuid) - if err != nil { return err } @@ -1572,7 +1561,6 @@ func (a *OcpSandboxProvider) FetchAll() ([]OcpSandbox, error) { FROM resources r LEFT JOIN ocp_shared_cluster_configurations oc ON oc.name = r.resource_data->>'ocp_cluster'`, ) - if err != nil { if err == pgx.ErrNoRows { log.Logger.Info("No account found") @@ -1603,7 +1591,6 @@ func (a *OcpSandboxProvider) FetchAll() ([]OcpSandbox, error) { } func (account *OcpSandboxWithCreds) Delete() error { - if account.ID == 0 { return errors.New("resource ID must be > 0") } @@ -1657,7 +1644,6 @@ func (account *OcpSandboxWithCreds) Delete() error { "SELECT resource_data->>'ocp_cluster' FROM resources WHERE id = $1", account.ID, ).Scan(&account.OcpSharedClusterConfigurationName) - if err != nil { if err == pgx.ErrNoRows { log.Logger.Error("Ocp cluster doesn't exist for resource", "name", account.Name) diff --git a/internal/models/placements.go b/internal/models/placements.go index 15caa2b..5849446 100644 --- a/internal/models/placements.go +++ b/internal/models/placements.go @@ -41,9 +41,7 @@ func (p *Placement) Render(w http.ResponseWriter, r *http.Request) error { } func (p *Placement) LoadResources(awsProvider AwsAccountProvider, ocpProvider OcpSandboxProvider) error { - accounts, err := awsProvider.FetchAllByServiceUuid(p.ServiceUuid) - if err != nil { return err } @@ -58,7 +56,6 @@ func (p *Placement) LoadResources(awsProvider AwsAccountProvider, ocpProvider Oc status := "success" ocpSandboxes, err := ocpProvider.FetchAllByServiceUuid(p.ServiceUuid) - if err != nil { return err } @@ -86,9 +83,7 @@ func (p *Placement) LoadResources(awsProvider AwsAccountProvider, ocpProvider Oc } func (p *Placement) LoadResourcesWithCreds(awsProvider AwsAccountProvider, ocpProvider OcpSandboxProvider) error { - accounts, err := awsProvider.FetchAllByServiceUuidWithCreds(p.ServiceUuid) - if err != nil { return err } @@ -102,7 +97,6 @@ func (p *Placement) LoadResourcesWithCreds(awsProvider AwsAccountProvider, ocpPr status := "success" ocpSandboxes, err := ocpProvider.FetchAllByServiceUuidWithCreds(p.ServiceUuid) - if err != nil { return err } @@ -131,7 +125,6 @@ func (p *Placement) LoadResourcesWithCreds(awsProvider AwsAccountProvider, ocpPr func (p *Placement) LoadActiveResources(awsProvider AwsAccountProvider) error { accounts, err := awsProvider.FetchAllActiveByServiceUuid(p.ServiceUuid) - if err != nil { return err } @@ -148,9 +141,7 @@ func (p *Placement) LoadActiveResources(awsProvider AwsAccountProvider) error { } func (p *Placement) LoadActiveResourcesWithCreds(awsProvider AwsAccountProvider, ocpProvider OcpSandboxProvider) error { - accounts, err := awsProvider.FetchAllActiveByServiceUuidWithCreds(p.ServiceUuid) - if err != nil { return err } @@ -164,7 +155,6 @@ func (p *Placement) LoadActiveResourcesWithCreds(awsProvider AwsAccountProvider, status := "success" ocpSandboxes, err := ocpProvider.FetchAllByServiceUuidWithCreds(p.ServiceUuid) - if err != nil { return err } @@ -216,7 +206,6 @@ func (p *Placement) Create() error { VALUES ($1, $2, $3) RETURNING id`, p.ServiceUuid, p.Request, p.Annotations, ).Scan(&id) - if err != nil { return err } @@ -238,7 +227,7 @@ func (p *Placement) Create() error { } // Delete deletes a placement -func (p *Placement) Delete(accountProvider AwsAccountProvider, ocpProvider OcpSandboxProvider) { +func (p *Placement) Delete(accountProvider AwsAccountProvider, ocpProvider OcpSandboxProvider, azureProvider *AzureSandboxProvider) { if err := p.SetStatus("deleting"); err != nil { log.Logger.Error("error setting status for placement", "serviceUuid", p.ServiceUuid, @@ -265,11 +254,16 @@ func (p *Placement) Delete(accountProvider AwsAccountProvider, ocpProvider OcpSa return } + if err := azureProvider.Release(p.ServiceUuid); err != nil { + log.Logger.Error("Error while releasing Azure sandboxes") + p.SetStatus("error") + return + } + _, err := p.DbPool.Exec( context.Background(), "DELETE FROM placements WHERE id = $1", p.ID, ) - if err != nil { p.SetStatus("error") return @@ -300,7 +294,6 @@ func (p *Placement) GetLastStatus() ([]*LifecycleResourceJob, error) { ORDER BY updated_at DESC LIMIT 1`, p.ID, ).Scan(&id) - if err != nil { return nil, err } @@ -313,7 +306,6 @@ func (p *Placement) GetLastStatus() ([]*LifecycleResourceJob, error) { ORDER BY updated_at`, id, ) - if err != nil { return nil, err } @@ -330,7 +322,6 @@ func (p *Placement) GetLastStatus() ([]*LifecycleResourceJob, error) { } job, err := GetLifecycleResourceJob(p.DbPool, idR) - if err != nil { return result, err } @@ -371,7 +362,6 @@ func GetPlacement(dbpool *pgxpool.Pool, id int) (*Placement, error) { &p.ToCleanup, &p.CreatedAt, &p.UpdatedAt) - if err != nil { return nil, err } @@ -397,7 +387,6 @@ func GetAllPlacements(dbpool *pgxpool.Pool) (Placements, error) { updated_at FROM placements`, ) - if err != nil { return nil, err } @@ -456,7 +445,6 @@ func GetPlacementByServiceUuid(dbpool *pgxpool.Pool, serviceUuid string) (*Place &p.ToCleanup, &p.CreatedAt, &p.UpdatedAt) - if err != nil { return nil, err } @@ -466,7 +454,7 @@ func GetPlacementByServiceUuid(dbpool *pgxpool.Pool, serviceUuid string) (*Place } // DeletePlacementByServiceUuid deletes a placement by ServiceUuid -func DeletePlacementByServiceUuid(dbpool *pgxpool.Pool, awsProvider AwsAccountProvider, ocpProvider OcpSandboxProvider, serviceUuid string) error { +func DeletePlacementByServiceUuid(dbpool *pgxpool.Pool, awsProvider AwsAccountProvider, ocpProvider OcpSandboxProvider, azureProvider *AzureSandboxProvider, serviceUuid string) error { placement, err := GetPlacementByServiceUuid(dbpool, serviceUuid) if err != nil { return err @@ -479,7 +467,7 @@ func DeletePlacementByServiceUuid(dbpool *pgxpool.Pool, awsProvider AwsAccountPr return err } - go placement.Delete(awsProvider, ocpProvider) + go placement.Delete(awsProvider, ocpProvider, azureProvider) return nil } @@ -491,7 +479,6 @@ func (p *Placement) SetStatus(status string) error { status, p.ID, ) - if err != nil { log.Logger.Error("Error setting status", "error", err) return err @@ -508,7 +495,6 @@ func (p *Placement) MarkForCleanup() error { "UPDATE placements SET to_cleanup = true WHERE id = $1", p.ID, ) - if err != nil { return err } diff --git a/tests/006_azure.hurl b/tests/006_azure.hurl new file mode 100644 index 0000000..96d5de3 --- /dev/null +++ b/tests/006_azure.hurl @@ -0,0 +1,66 @@ +################################################################################# +# Get an access token using the login token +################################################################################# + +GET {{host}}/api/v1/login +Authorization: Bearer {{login_token}} +HTTP 200 +[Captures] +access_token: jsonpath "$.access_token" +[Asserts] +jsonpath "$.access_token" isString +jsonpath "$.access_token_exp" isString + +################################################################################# +# Ensure placement doesn't exist +################################################################################# + +GET {{host}}/api/v1/placements/{{uuid}} +Authorization: Bearer {{access_token}} +[Options] +retry: 10 +HTTP 404 + +################################################################################# +# Create a new placement +################################################################################# + +POST {{host}}/api/v1/placements +Authorization: Bearer {{access_token}} +{ + "service_uuid": "{{uuid}}", + "resources": [ + { + "kind": "AzureSandbox", + "annotations": { + "purpose": "backend", + "requester": "{{requester}}", + "cost_center": "{{cost_center}}", + "domain": "{{domain}}" + } + } + ], + "annotations": { + "guid": "st6zn" + } +} +HTTP 200 +[Captures] +account: jsonpath "$.Placement.resources[0].subscription_name" +[Asserts] +jsonpath "$.message" == "Placement Created" +jsonpath "$.Placement.service_uuid" == "{{uuid}}" +jsonpath "$.Placement.resources" count == 1 +jsonpath "$.Placement.resources[0].annotations.guid" == "st6zn" +jsonpath "$.Placement.resources[0].annotations.purpose" == "backend" + + +################################################################################# +# Get placement and ensure credentials are present +################################################################################# + +GET {{host}}/api/v1/placements/{{uuid}} +Authorization: Bearer {{access_token}} +HTTP 200 +[Asserts] +jsonpath "$.resources[0].credentials" count >= 1 diff --git a/tests/readme.adoc b/tests/readme.adoc index de6f0ef..e26965a 100644 --- a/tests/readme.adoc +++ b/tests/readme.adoc @@ -105,3 +105,23 @@ hurl --variable login_token_admin=$admintoken \ === Troubleshoot === Add the `--verbose` argument to the `hurl` command to see the full requests. + +== Azure tests == + +Variables 'requester', 'cost_center', 'domain' need to be set +---- +export HURL_requester="...@redhat.com" +export HURL_pool_id="..." +export HURL_cost_center="..." +export HURL_domain="..." + + +hurl --variable login_token_admin=$admintoken \ +--variable login_token=$apptoken \ +--variable host=http://localhost:8080 \ +--variable requester=${HURL_requester} \ +--variable cost_center=${HURL_cost_center} \ +--variable domain=${HURL_domain} \ +--variable uuid=$uuid \ +006_azure.hurl --test +----