From cec630746b9aab3f46a1838b9257463e76fbc21e Mon Sep 17 00:00:00 2001 From: Victor Vazquez Date: Mon, 22 Jul 2024 21:17:35 +0000 Subject: [PATCH 01/12] adding sqlcmd tool support --- cli/azd/.vscode/cspell-azd-dictionary.txt | 3 +- cli/azd/pkg/tools/sqlcmd/sqlcmd.go | 376 ++++++++++++++++++++++ cli/azd/pkg/tools/sqlcmd/sqlcmd_test.go | 191 +++++++++++ go.mod | 1 + go.sum | 6 + 5 files changed, 576 insertions(+), 1 deletion(-) create mode 100644 cli/azd/pkg/tools/sqlcmd/sqlcmd.go create mode 100644 cli/azd/pkg/tools/sqlcmd/sqlcmd_test.go diff --git a/cli/azd/.vscode/cspell-azd-dictionary.txt b/cli/azd/.vscode/cspell-azd-dictionary.txt index 1834e4f1526..492ae823395 100644 --- a/cli/azd/.vscode/cspell-azd-dictionary.txt +++ b/cli/azd/.vscode/cspell-azd-dictionary.txt @@ -179,6 +179,7 @@ servicebus setenvs snapshotter springapp +sqlcmd sqlserver sstore staticcheck @@ -216,4 +217,4 @@ webfrontend westus2 wireinject yacspin -zerr +zerr \ No newline at end of file diff --git a/cli/azd/pkg/tools/sqlcmd/sqlcmd.go b/cli/azd/pkg/tools/sqlcmd/sqlcmd.go new file mode 100644 index 00000000000..6cf2a921f7e --- /dev/null +++ b/cli/azd/pkg/tools/sqlcmd/sqlcmd.go @@ -0,0 +1,376 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package sqlcmd + +import ( + "archive/tar" + "archive/zip" + "compress/bzip2" + "context" + "errors" + "fmt" + "io" + "log" + "net/http" + "os" + "path/filepath" + "runtime" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/azure/azure-dev/cli/azd/internal/tracing" + "github.com/azure/azure-dev/cli/azd/internal/tracing/events" + "github.com/azure/azure-dev/cli/azd/pkg/config" + "github.com/azure/azure-dev/cli/azd/pkg/exec" + "github.com/azure/azure-dev/cli/azd/pkg/input" + "github.com/azure/azure-dev/cli/azd/pkg/osutil" + "github.com/azure/azure-dev/cli/azd/pkg/tools" + "github.com/blang/semver/v4" +) + +// sqlCmdCliVersion is the minimum version of sqlCmd cli that we require +var sqlCmdCliVersion semver.Version = semver.MustParse("1.8.0") + +func NewSqlCmdCli(ctx context.Context, console input.Console, commandRunner exec.CommandRunner) (*SqlCmdCli, error) { + return newSqlCmdCliImplementation(ctx, console, commandRunner, http.DefaultClient, downloadSqlCmd, extractSqlCmdCli) +} + +// NewSqlCmdCliImplementation is like NewSqlCmdCli but allows providing a custom transport to use when downloading the +// sqlCmd CLI, for testing purposes. +func newSqlCmdCliImplementation( + ctx context.Context, + console input.Console, + commandRunner exec.CommandRunner, + transporter policy.Transporter, + acquireSqlCmdCliImpl getSqlCmdCliImplementation, + extractImplementation extractSqlCmdCliFromFileImplementation, +) (*SqlCmdCli, error) { + if override := os.Getenv("AZD_SQL_CMD_CLI_TOOL_PATH"); override != "" { + log.Printf("using external sqlCmd cli tool: %s", override) + cli := &SqlCmdCli{ + path: override, + commandRunner: commandRunner, + } + cli.logVersion(ctx) + + return cli, nil + } + + sqlCmdCliPath, err := azdSqlCmdCliPath() + if err != nil { + return nil, fmt.Errorf("getting sqlCmd cli default path: %w", err) + } + + if _, err = os.Stat(sqlCmdCliPath); err != nil && !errors.Is(err, os.ErrNotExist) { + return nil, fmt.Errorf("getting file information from sqlCmd cli default path: %w", err) + } + var installSqlCmdCli bool + if errors.Is(err, os.ErrNotExist) || !expectedVersionInstalled(ctx, commandRunner, sqlCmdCliPath) { + installSqlCmdCli = true + } + if installSqlCmdCli { + if err := os.MkdirAll(filepath.Dir(sqlCmdCliPath), osutil.PermissionDirectory); err != nil { + return nil, fmt.Errorf("creating sqlCmd cli default path: %w", err) + } + + msg := "setting up sqlCmd connection" + console.ShowSpinner(ctx, msg, input.Step) + err = acquireSqlCmdCliImpl(ctx, transporter, sqlCmdCliVersion, extractImplementation, sqlCmdCliPath) + console.StopSpinner(ctx, "", input.Step) + if err != nil { + return nil, fmt.Errorf("setting up sqlCmd connection: %w", err) + } + } + + cli := &SqlCmdCli{ + path: sqlCmdCliPath, + commandRunner: commandRunner, + } + cli.logVersion(ctx) + return cli, nil +} + +func (cli *SqlCmdCli) logVersion(ctx context.Context) { + if ver, err := cli.extractVersion(ctx); err == nil { + log.Printf("sqlcmd cli version: %s", ver) + } else { + log.Printf("could not determine github cli version: %s", err) + } +} + +// extractVersion gets the version of the sqlCmd CLI, from the output of `sqlCmd --version` +func (cli *SqlCmdCli) extractVersion(ctx context.Context) (string, error) { + runArgs := cli.newRunArgs("--version") + res, err := cli.run(ctx, runArgs) + if err != nil { + return "", fmt.Errorf("error running sqlcmd --version: %w", err) + } + return res.Stdout, nil +} + +// azdSqlCmdCliPath returns the path where we store our local copy of sqlCmd cli ($AZD_CONFIG_DIR/bin). +func azdSqlCmdCliPath() (string, error) { + configDir, err := config.GetUserConfigDir() + if err != nil { + return "", err + } + return filepath.Join(configDir, "bin", sqlCmdCliName()), nil +} + +func sqlCmdCliName() string { + if runtime.GOOS == "windows" { + return "sqlcmd.exe" + } + return "sqlcmd" +} + +type SqlCmdCli struct { + commandRunner exec.CommandRunner + path string +} + +func expectedVersionInstalled(ctx context.Context, commandRunner exec.CommandRunner, binaryPath string) bool { + sqlCmdVersion, err := tools.ExecuteCommand(ctx, commandRunner, binaryPath, "--version") + if err != nil { + log.Printf("checking %s version: %s", sqlCmdToolName, err.Error()) + return false + } + sqlCmdSemver, err := tools.ExtractVersion(sqlCmdVersion) + if err != nil { + log.Printf("converting to semver version fails: %s", err.Error()) + return false + } + if sqlCmdSemver.LT(sqlCmdCliVersion) { + log.Printf("Found sqlCmd cli version %s. Expected version: %s.", sqlCmdSemver.String(), sqlCmdCliVersion.String()) + return false + } + return true +} + +const sqlCmdToolName = "sqlCmd CLI" + +func (cli *SqlCmdCli) Name() string { + return sqlCmdToolName +} + +func (cli *SqlCmdCli) BinaryPath() string { + return cli.path +} + +func (cli *SqlCmdCli) InstallUrl() string { + return "https://github.com/microsoft/go-sqlcmd" +} + +func (cli *SqlCmdCli) newRunArgs(args ...string) exec.RunArgs { + return exec.NewRunArgs(cli.path, args...) +} + +func (cli *SqlCmdCli) run(ctx context.Context, runArgs exec.RunArgs) (exec.RunResult, error) { + return cli.commandRunner.Run(ctx, runArgs) +} + +func extractFromZip(src, dst string) (string, error) { + zipReader, err := zip.OpenReader(src) + if err != nil { + return "", err + } + + log.Printf("extract from zip %s", src) + defer zipReader.Close() + + var extractedAt string + for _, file := range zipReader.File { + fileName := file.FileInfo().Name() + if !file.FileInfo().IsDir() && fileName == sqlCmdCliName() { + log.Printf("found cli at: %s", file.Name) + fileReader, err := file.Open() + if err != nil { + return extractedAt, err + } + filePath := filepath.Join(dst, fileName) + sqlCmdCliFile, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, file.Mode()) + if err != nil { + return extractedAt, err + } + defer sqlCmdCliFile.Close() + /* #nosec G110 - decompression bomb false positive */ + _, err = io.Copy(sqlCmdCliFile, fileReader) + if err != nil { + return extractedAt, err + } + extractedAt = filePath + break + } + } + if extractedAt != "" { + log.Printf("extracted to: %s", extractedAt) + return extractedAt, nil + } + return extractedAt, fmt.Errorf("sqlCmd cli binary was not found within the zip file") +} + +func extractFromTar(src, dst string) (string, error) { + bz2File, err := os.Open(src) + if err != nil { + return "", err + } + defer bz2File.Close() + + bz2Reader := bzip2.NewReader(bz2File) + + var extractedAt string + // tarReader doesn't need to be closed as it is closed by the gz reader + tarReader := tar.NewReader(bz2Reader) + for { + fileHeader, err := tarReader.Next() + if errors.Is(err, io.EOF) { + return extractedAt, fmt.Errorf("did not find sqlcmd cli within tar file") + } + if fileHeader == nil { + continue + } + if err != nil { + return extractedAt, err + } + // Tha name contains the path, remove it + fileNameParts := strings.Split(fileHeader.Name, "/") + fileName := fileNameParts[len(fileNameParts)-1] + // cspell: disable-next-line `Typeflag` is comming fron *tar.Header + if fileHeader.Typeflag == tar.TypeReg && fileName == "sqlcmd" { + filePath := filepath.Join(dst, fileName) + sqlCmdCliFile, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.FileMode(fileHeader.Mode)) + if err != nil { + return extractedAt, err + } + defer sqlCmdCliFile.Close() + /* #nosec G110 - decompression bomb false positive */ + _, err = io.Copy(sqlCmdCliFile, tarReader) + if err != nil { + return extractedAt, err + } + extractedAt = filePath + break + } + } + if extractedAt != "" { + return extractedAt, nil + } + return extractedAt, fmt.Errorf("extract from tar error. Extraction ended in unexpected state.") +} + +// extractSqlCmdCli gets the sqlCmd cli from either a zip or a tar.gz +func extractSqlCmdCli(src, dst string) (string, error) { + if strings.HasSuffix(src, ".zip") { + return extractFromZip(src, dst) + } else if strings.HasSuffix(src, ".tar.bz2") { + return extractFromTar(src, dst) + } + return "", fmt.Errorf("Unknown format while trying to extract") +} + +// getSqlCmdCliImplementation defines the contract function to acquire the sqlCmd cli. +// The `outputPath` is the destination where the sqlCmd cli is place it. +type getSqlCmdCliImplementation func( + ctx context.Context, + transporter policy.Transporter, + sqlCmdVersion semver.Version, + extractImplementation extractSqlCmdCliFromFileImplementation, + outputPath string) error + +// extractSqlCmdCliFromFileImplementation defines how the cli is extracted +type extractSqlCmdCliFromFileImplementation func(src, dst string) (string, error) + +// downloadSqlCmd downloads a given version of sqlCmd cli from the release site. +func downloadSqlCmd( + ctx context.Context, + transporter policy.Transporter, + sqlCmdVersion semver.Version, + extractImplementation extractSqlCmdCliFromFileImplementation, + path string) error { + + binaryName := func(platform string) string { + return fmt.Sprintf("sqlcmd-%s", platform) + } + + systemArch := runtime.GOARCH + // arm and x86 not supported (similar to bicep) + var releaseName string + switch runtime.GOOS { + case "windows": + releaseName = binaryName(fmt.Sprintf("windows-%s.zip", systemArch)) + case "darwin": + releaseName = binaryName(fmt.Sprintf("darwin-%s.tar.bz2", systemArch)) + case "linux": + releaseName = binaryName(fmt.Sprintf("linux-%s.tar.bz2", systemArch)) + default: + return fmt.Errorf("unsupported platform") + } + + sqlCmdReleaseUrl := fmt.Sprintf( + "https://github.com/microsoft/go-sqlcmd/releases/download/v%s/%s", sqlCmdVersion, releaseName) + + log.Printf("downloading sqlCmd cli release %s -> %s", sqlCmdReleaseUrl, releaseName) + + spanCtx, span := tracing.Start(ctx, events.SqlCmdCliInstallEvent) + defer span.End() + + req, err := http.NewRequestWithContext(spanCtx, "GET", sqlCmdReleaseUrl, nil) + if err != nil { + return err + } + + resp, err := transporter.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("http error %d", resp.StatusCode) + } + + tmpPath := filepath.Dir(path) + compressedRelease, err := os.CreateTemp(tmpPath, releaseName) + if err != nil { + return err + } + defer func() { + _ = compressedRelease.Close() + _ = os.Remove(compressedRelease.Name()) + }() + + if _, err := io.Copy(compressedRelease, resp.Body); err != nil { + return err + } + if err := compressedRelease.Close(); err != nil { + return err + } + + // change file name from temporal name to the final name, as the download has completed + compressedFileName := filepath.Join(tmpPath, releaseName) + if err := osutil.Rename(ctx, compressedRelease.Name(), compressedFileName); err != nil { + return err + } + defer func() { + log.Printf("delete %s", compressedFileName) + _ = os.Remove(compressedFileName) + }() + + // unzip downloaded file + log.Printf("extracting file %s", compressedFileName) + _, err = extractImplementation(compressedFileName, tmpPath) + if err != nil { + return err + } + + return nil +} + +func (cli *SqlCmdCli) ExecuteScript(ctx context.Context, server, dbName, path string, env []string) (string, error) { + runArgs := cli.newRunArgs("-G", "-l", "30", "-S", server, "-d", dbName, "-i", path).WithEnv(env) + res, err := cli.run(ctx, runArgs) + if err != nil { + return "", fmt.Errorf("error running sqlcmd: %w", err) + } + return res.Stdout, nil +} diff --git a/cli/azd/pkg/tools/sqlcmd/sqlcmd_test.go b/cli/azd/pkg/tools/sqlcmd/sqlcmd_test.go new file mode 100644 index 00000000000..d81bfaf0425 --- /dev/null +++ b/cli/azd/pkg/tools/sqlcmd/sqlcmd_test.go @@ -0,0 +1,191 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package sqlcmd + +import ( + "archive/tar" + "archive/zip" + "bytes" + "context" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/exec" + "github.com/azure/azure-dev/cli/azd/pkg/input" + "github.com/azure/azure-dev/cli/azd/pkg/osutil" + "github.com/azure/azure-dev/cli/azd/test/mocks" + "github.com/azure/azure-dev/cli/azd/test/mocks/mockinput" + "github.com/dsnet/compress/bzip2" + + "github.com/stretchr/testify/require" +) + +func TestNewSqlCmdHubCli(t *testing.T) { + configRoot := t.TempDir() + t.Setenv("AZD_CONFIG_DIR", configRoot) + + mockContext := mocks.NewMockContext(context.Background()) + + mockContext.HttpClient.When(func(request *http.Request) bool { + return request.Method == http.MethodGet && request.URL.Host == "github.com" + }).Respond(&http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString("this is sqlcmd cli")), + }) + + mockContext.CommandRunner.When(func(args exec.RunArgs, command string) bool { + return strings.Contains(args.Cmd, "sqlcmd") && len(args.Args) == 1 && args.Args[0] == "--version" + }).Respond(exec.NewRunResult( + 0, + fmt.Sprintf("%s", sqlCmdCliVersion.String()), + "", + )) + + mockExtract := func(src, dst string) (string, error) { + exp, _ := azdSqlCmdCliPath() + _ = osutil.Rename(context.Background(), src, exp) + return src, nil + } + + cli, err := newSqlCmdCliImplementation( + *mockContext.Context, + mockContext.Console, + mockContext.CommandRunner, + mockContext.HttpClient, + downloadSqlCmd, + mockExtract, + ) + require.NoError(t, err) + require.NotNil(t, cli) + + require.Equal(t, 2, len(mockContext.Console.SpinnerOps())) + + require.Equal(t, mockinput.SpinnerOp{ + Op: mockinput.SpinnerOpShow, + Message: "setting up sqlCmd connection", + Format: input.Step, + }, mockContext.Console.SpinnerOps()[0]) + + sqlCmdCli, err := azdSqlCmdCliPath() + require.NoError(t, err) + + contents, err := os.ReadFile(sqlCmdCli) + require.NoError(t, err) + + require.Equal(t, []byte("this is sqlcmd cli"), contents) + + ver, err := cli.extractVersion(context.Background()) + require.NoError(t, err) + require.Equal(t, sqlCmdCliVersion.String(), ver) +} + +func TestZipExtractContents(t *testing.T) { + testPath := t.TempDir() + expectedPhrase := "this will be inside a zip file" + zipFilePath, err := createSampleZip(testPath, expectedPhrase, "bin/"+sqlCmdCliName()) + require.NoError(t, err) + ghCliPath, err := extractSqlCmdCli(zipFilePath, testPath) + require.NoError(t, err) + + content, err := os.ReadFile(ghCliPath) + require.NoError(t, err) + require.EqualValues(t, []byte(expectedPhrase), content) +} + +func TestTarExtractContents(t *testing.T) { + testPath := t.TempDir() + expectedPhrase := "this will be inside a tar file" + tarFilePath, err := createSampleTarBz2(testPath, expectedPhrase, "sqlcmd") + require.NoError(t, err) + ghCliPath, err := extractSqlCmdCli(tarFilePath, testPath) + require.NoError(t, err) + + content, err := os.ReadFile(ghCliPath) + require.NoError(t, err) + require.EqualValues(t, []byte(expectedPhrase), content) +} + +func createSampleZip(path, content, file string) (string, error) { + filePath := filepath.Join(path, "zippedFile.zip") + zipFile, err := os.Create(filePath) + if err != nil { + return "", err + } + defer zipFile.Close() + + contentReader := strings.NewReader(content) + zipWriter := zip.NewWriter(zipFile) + + zipContent, err := zipWriter.Create(file) + if err != nil { + return "", err + } + + if _, err := io.Copy(zipContent, contentReader); err != nil { + return "", err + } + + zipWriter.Close() + + return filePath, nil +} + +func createSampleTarBz2(path, content, file string) (string, error) { + filePath := filepath.Join(path, "zippedFile.tar.bz2") + tarFile, err := os.Create(filePath) + if err != nil { + return "", err + } + defer tarFile.Close() + + gzWriter, err := bzip2.NewWriter(tarFile, nil) + if err != nil { + return "", err + } + defer gzWriter.Close() + + tarWriter := tar.NewWriter(gzWriter) + defer tarWriter.Close() + + // not sure how tar from memory. Let's create an extra file with content + fileContentPath := filepath.Join(path, file) + fileContent, err := os.Create(fileContentPath) + if err != nil { + return "", err + } + if _, err := fileContent.WriteString(content); err != nil { + return "", err + } + fileContent.Close() + + // tar the file + fileInfo, err := os.Stat(fileContentPath) + if err != nil { + return "", err + } + tarHeader, err := tar.FileInfoHeader(fileInfo, fileInfo.Name()) + if err != nil { + return "", err + } + if err := tarWriter.WriteHeader(tarHeader); err != nil { + return "", nil + } + fileContent, err = os.Open(fileContentPath) + defer func() { + _ = fileContent.Close() + }() + if err != nil { + return "", err + } + if _, err := io.Copy(tarWriter, fileContent); err != nil { + return "", err + } + + return filePath, nil +} diff --git a/go.mod b/go.mod index 959329bace2..8d78eef713f 100644 --- a/go.mod +++ b/go.mod @@ -76,6 +76,7 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v0.8.0 // indirect github.com/cenkalti/backoff/v4 v4.1.3 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dsnet/compress v0.0.1 github.com/go-logr/logr v1.2.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/golang-jwt/jwt/v5 v5.2.1 // indirect diff --git a/go.sum b/go.sum index d666fc23a2b..ff608c4eb25 100644 --- a/go.sum +++ b/go.sum @@ -178,6 +178,9 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/drone/envsubst v1.0.3 h1:PCIBwNDYjs50AsLZPYdfhSATKaRg/FJmDc2D6+C2x8g= github.com/drone/envsubst v1.0.3/go.mod h1:N2jZmlMufstn1KEqvbHjw40h1KyTmnVzHcSc9bFiJ2g= +github.com/dsnet/compress v0.0.1 h1:PlZu0n3Tuv04TzpfPbrnI0HW/YwodEXDS+oPKahKF0Q= +github.com/dsnet/compress v0.0.1/go.mod h1:Aw8dCMJ7RioblQeTqt88akK31OvO8Dhf5JflhBbQEHo= +github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= @@ -356,6 +359,8 @@ github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNU github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/klauspost/compress v1.4.1/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= +github.com/klauspost/cpuid v1.2.0/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= @@ -501,6 +506,7 @@ github.com/tedsuo/ifrit v0.0.0-20180802180643-bea94bb476cc/go.mod h1:eyZnKCc955u github.com/theckman/yacspin v0.13.12 h1:CdZ57+n0U6JMuh2xqjnjRq5Haj6v1ner2djtLQRzJr4= github.com/theckman/yacspin v0.13.12/go.mod h1:Rd2+oG2LmQi5f3zC3yeZAOl245z8QOvrH4OPOJNZxLg= github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM= +github.com/ulikunitz/xz v0.5.6/go.mod h1:2bypXElzHzzJZwzH67Y6wb67pO62Rzfn7BSiF4ABRW8= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= From e99fa21ec8071e3c264414f89368d6ebc37d7768 Mon Sep 17 00:00:00 2001 From: Victor Vazquez Date: Mon, 22 Jul 2024 21:17:50 +0000 Subject: [PATCH 02/12] sqlcmd to container --- cli/azd/cmd/container.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cli/azd/cmd/container.go b/cli/azd/cmd/container.go index 03424d0314c..fba74998a7b 100644 --- a/cli/azd/cmd/container.go +++ b/cli/azd/cmd/container.go @@ -61,6 +61,7 @@ import ( "github.com/azure/azure-dev/cli/azd/pkg/tools/maven" "github.com/azure/azure-dev/cli/azd/pkg/tools/npm" "github.com/azure/azure-dev/cli/azd/pkg/tools/python" + "github.com/azure/azure-dev/cli/azd/pkg/tools/sqlcmd" "github.com/azure/azure-dev/cli/azd/pkg/tools/swa" "github.com/azure/azure-dev/cli/azd/pkg/workflow" "github.com/mattn/go-colorable" @@ -602,6 +603,7 @@ func registerCommonDependencies(container *ioc.NestedContainer) { container.MustRegisterSingleton(dotnet.NewDotNetCli) container.MustRegisterSingleton(git.NewGitCli) container.MustRegisterSingleton(github.NewGitHubCli) + container.MustRegisterSingleton(sqlcmd.NewSqlCmdCli) container.MustRegisterSingleton(javac.NewCli) container.MustRegisterSingleton(kubectl.NewKubectl) container.MustRegisterSingleton(maven.NewMavenCli) From f5c968d764be0934baac869079e96265b0dadfd5 Mon Sep 17 00:00:00 2001 From: Victor Vazquez Date: Mon, 22 Jul 2024 21:22:02 +0000 Subject: [PATCH 03/12] tracking event --- cli/azd/internal/tracing/events/events.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cli/azd/internal/tracing/events/events.go b/cli/azd/internal/tracing/events/events.go index 0aac13f840d..75730ad36de 100644 --- a/cli/azd/internal/tracing/events/events.go +++ b/cli/azd/internal/tracing/events/events.go @@ -21,6 +21,9 @@ const BicepInstallEvent = "tools.bicep.install" // GitHubCliInstallEvent is the name of the event which tracks the overall GitHub cli install operation. const GitHubCliInstallEvent = "tools.gh.install" +// SqlCmdCliInstallEvent is the name of the event which tracks the overall sqlcmd cli install operation. +const SqlCmdCliInstallEvent = "tools.sqlCmd.install" + // PackCliInstallEvent is the name of the event which tracks the overall pack cli install operation. const PackCliInstallEvent = "tools.pack.install" From 5e207f125868bda38c94749ab9f38913b1d8a93a Mon Sep 17 00:00:00 2001 From: Victor Vazquez Date: Mon, 22 Jul 2024 21:28:01 +0000 Subject: [PATCH 04/12] Support for bicep azd-metadata input parameters - login, id and type. --- cli/azd/pkg/auth/manager.go | 29 +++++++++ cli/azd/pkg/azure/arm_template.go | 3 + cli/azd/pkg/azureutil/principal.go | 47 +++++++++++++++ .../provisioning/bicep/bicep_provider.go | 60 +++++++++++++++---- .../infra/provisioning/bicep/prompt_test.go | 7 +++ .../current_principal_id_provider.go | 15 +++++ .../terraform/terraform_provider_test.go | 7 +++ cli/azd/pkg/tools/azcli/user_profile.go | 38 +++++++++++- cli/azd/pkg/tools/azcli/user_profile_test.go | 3 + cli/azd/pkg/tools/sqlcmd/sqlcmd_test.go | 3 +- 10 files changed, 194 insertions(+), 18 deletions(-) diff --git a/cli/azd/pkg/auth/manager.go b/cli/azd/pkg/auth/manager.go index 07cfabb94f3..c5e8e30a94f 100644 --- a/cli/azd/pkg/auth/manager.go +++ b/cli/azd/pkg/auth/manager.go @@ -456,6 +456,35 @@ func (m *Manager) GetLoggedInServicePrincipalTenantID(ctx context.Context) (*str return currentUser.TenantID, nil } +func (m *Manager) GetLoggedInServicePrincipalID(ctx context.Context) (*string, error) { + if m.UseExternalAuth() { + // When delegating to an external system, we have no way to determine what principal was used + return nil, nil + } + + cfg, err := m.userConfigManager.Load() + if err != nil { + return nil, fmt.Errorf("fetching current user: %w", err) + } + + if shouldUseLegacyAuth(cfg) { + // When delegating to az, we have no way to determine what principal was used + return nil, nil + } + + authCfg, err := m.readAuthConfig() + if err != nil { + return nil, fmt.Errorf("fetching auth config: %w", err) + } + + currentUser, err := readUserProperties(authCfg) + if err != nil { + return nil, ErrNoCurrentUser + } + + return currentUser.ClientID, nil +} + func (m *Manager) newCredentialFromManagedIdentity(clientID string) (azcore.TokenCredential, error) { options := &azidentity.ManagedIdentityCredentialOptions{} if clientID != "" { diff --git a/cli/azd/pkg/azure/arm_template.go b/cli/azd/pkg/azure/arm_template.go index 267e5c33122..a8f4c4fc79d 100644 --- a/cli/azd/pkg/azure/arm_template.go +++ b/cli/azd/pkg/azure/arm_template.go @@ -104,6 +104,9 @@ type AzdMetadataType string const AzdMetadataTypeLocation AzdMetadataType = "location" const AzdMetadataTypeGenerate AzdMetadataType = "generate" +const AzdMetadataTypePrincipalLogin AzdMetadataType = "principalLogin" +const AzdMetadataTypePrincipalId AzdMetadataType = "principalId" +const AzdMetadataTypePrincipalType AzdMetadataType = "principalType" const AzdMetadataTypeGenerateOrManual AzdMetadataType = "generateOrManual" type AzdMetadata struct { diff --git a/cli/azd/pkg/azureutil/principal.go b/cli/azd/pkg/azureutil/principal.go index 7d63b005db8..481ffdea214 100644 --- a/cli/azd/pkg/azureutil/principal.go +++ b/cli/azd/pkg/azureutil/principal.go @@ -6,6 +6,7 @@ package azureutil import ( "context" "fmt" + "log" "github.com/azure/azure-dev/cli/azd/pkg/auth" "github.com/azure/azure-dev/cli/azd/pkg/tools/azcli" @@ -34,3 +35,49 @@ func GetCurrentPrincipalId(ctx context.Context, userProfile *azcli.UserProfileSe return oid, nil } + +type LoggedInPrincipalProfileData struct { + PrincipalId string + PrincipalType string + PrincipalLoginName string +} + +// LoggedInPrincipalProfile returns the info about the current logged in principal +func LoggedInPrincipalProfile( + ctx context.Context, userProfile *azcli.UserProfileService, tenantId string) (*LoggedInPrincipalProfileData, error) { + principalProfile, err := userProfile.SignedProfile(ctx, tenantId) + if err == nil { + return &LoggedInPrincipalProfileData{ + PrincipalId: principalProfile.Id, + PrincipalType: "User", + PrincipalLoginName: principalProfile.UserPrincipalName, + }, nil + } + + token, err := userProfile.GetAccessToken(ctx, tenantId) + if err != nil { + return nil, fmt.Errorf("getting access token: %w", err) + } + + tokenClaims, err := auth.GetClaimsFromAccessToken(token.AccessToken) + if err != nil { + return nil, fmt.Errorf("getting oid from token: %w", err) + } + + appProfile, err := userProfile.AppProfile(ctx, tenantId) + if err == nil { + return &LoggedInPrincipalProfileData{ + PrincipalId: *appProfile.AppId, + PrincipalType: "Application", + PrincipalLoginName: appProfile.DisplayName, + }, nil + } else { + log.Println(fmt.Errorf("fetching current user information: %w", err)) + } + + return &LoggedInPrincipalProfileData{ + PrincipalId: tokenClaims.LocalAccountId(), + PrincipalType: "User", + PrincipalLoginName: tokenClaims.Email, + }, nil +} diff --git a/cli/azd/pkg/infra/provisioning/bicep/bicep_provider.go b/cli/azd/pkg/infra/provisioning/bicep/bicep_provider.go index 53bffb72aeb..a8a52a9c743 100644 --- a/cli/azd/pkg/infra/provisioning/bicep/bicep_provider.go +++ b/cli/azd/pkg/infra/provisioning/bicep/bicep_provider.go @@ -1967,6 +1967,10 @@ func (p *BicepProvider) ensureParameters( key string param azure.ArmTemplateParameterDefinition } + currentPrincipalProfile, err := p.curPrincipal.CurrentPrincipalProfile(ctx) + if err != nil { + return nil, fmt.Errorf("fetching current principal profile: %w", err) + } for _, key := range sortedKeys { param := template.Parameters[key] @@ -2009,20 +2013,50 @@ func (p *BicepProvider) ensureParameters( // If the parameter is tagged with {type: "generate"}, skip prompting. // We generate it once, then save to config for next attempts.`. azdMetadata, hasMetadata := param.AzdMetadata() - if hasMetadata && parameterType == ParameterTypeString && azdMetadata.Type != nil && - *azdMetadata.Type == azure.AzdMetadataTypeGenerate { - - // - generate once - genValue, err := autoGenerate(key, azdMetadata) - if err != nil { - return nil, err - } - configuredParameters[key] = azure.ArmParameterValue{ - Value: genValue, + if hasMetadata && parameterType == ParameterTypeString && azdMetadata.Type != nil { + azdMetadataType := *azdMetadata.Type + switch azdMetadataType { + case azure.AzdMetadataTypeGenerate: + // - generate once + genValue, err := autoGenerate(key, azdMetadata) + if err != nil { + return nil, err + } + configuredParameters[key] = azure.ArmParameterValue{ + Value: genValue, + } + mustSetParamAsConfig(key, genValue, p.env.Config, param.Secure()) + configModified = true + continue + // Check metadata for auto-inject values [principalId, principalType, principalLogin] + case azure.AzdMetadataTypePrincipalLogin: + pLogin := currentPrincipalProfile.PrincipalLoginName + configuredParameters[key] = azure.ArmParameterValue{ + Value: pLogin, + } + mustSetParamAsConfig(key, pLogin, p.env.Config, param.Secure()) + configModified = true + continue + case azure.AzdMetadataTypePrincipalId: + pLogin := currentPrincipalProfile.PrincipalId + configuredParameters[key] = azure.ArmParameterValue{ + Value: pLogin, + } + mustSetParamAsConfig(key, pLogin, p.env.Config, param.Secure()) + configModified = true + continue + case azure.AzdMetadataTypePrincipalType: + pLogin := currentPrincipalProfile.PrincipalType + configuredParameters[key] = azure.ArmParameterValue{ + Value: pLogin, + } + mustSetParamAsConfig(key, pLogin, p.env.Config, param.Secure()) + configModified = true + continue + default: + // Do nothing + log.Println("Skipping actions for azd unknown metadata bicep parameter with type: ", azdMetadataType) } - mustSetParamAsConfig(key, genValue, p.env.Config, param.Secure()) - configModified = true - continue } // No saved value for this required parameter, we'll need to prompt for it. diff --git a/cli/azd/pkg/infra/provisioning/bicep/prompt_test.go b/cli/azd/pkg/infra/provisioning/bicep/prompt_test.go index 68a5010f90f..b5b8dd0e324 100644 --- a/cli/azd/pkg/infra/provisioning/bicep/prompt_test.go +++ b/cli/azd/pkg/infra/provisioning/bicep/prompt_test.go @@ -8,6 +8,7 @@ import ( "github.com/azure/azure-dev/cli/azd/pkg/account" "github.com/azure/azure-dev/cli/azd/pkg/azure" + "github.com/azure/azure-dev/cli/azd/pkg/azureutil" "github.com/azure/azure-dev/cli/azd/pkg/cloud" "github.com/azure/azure-dev/cli/azd/pkg/convert" "github.com/azure/azure-dev/cli/azd/pkg/environment" @@ -320,6 +321,12 @@ func TestPromptForParametersLocation(t *testing.T) { type mockCurrentPrincipal struct{} +// CurrentPrincipalProfile implements provisioning.CurrentPrincipalIdProvider. +func (m *mockCurrentPrincipal) CurrentPrincipalProfile( + ctx context.Context) (*azureutil.LoggedInPrincipalProfileData, error) { + return &azureutil.LoggedInPrincipalProfileData{}, nil +} + func (m *mockCurrentPrincipal) CurrentPrincipalId(_ context.Context) (string, error) { return "11111111-1111-1111-1111-111111111111", nil } diff --git a/cli/azd/pkg/infra/provisioning/current_principal_id_provider.go b/cli/azd/pkg/infra/provisioning/current_principal_id_provider.go index cac98fe3257..87529051566 100644 --- a/cli/azd/pkg/infra/provisioning/current_principal_id_provider.go +++ b/cli/azd/pkg/infra/provisioning/current_principal_id_provider.go @@ -14,6 +14,7 @@ type CurrentPrincipalIdProvider interface { // CurrentPrincipalId returns the object id of the current logged in principal, or an error if it can not be // determined. CurrentPrincipalId(ctx context.Context) (string, error) + CurrentPrincipalProfile(ctx context.Context) (*azureutil.LoggedInPrincipalProfileData, error) } func NewPrincipalIdProvider( @@ -47,3 +48,17 @@ func (p *principalIDProvider) CurrentPrincipalId(ctx context.Context) (string, e return principalId, nil } + +func (p *principalIDProvider) CurrentPrincipalProfile(ctx context.Context) (*azureutil.LoggedInPrincipalProfileData, error) { + tenantId, err := p.subResolver.LookupTenant(ctx, p.env.GetSubscriptionId()) + if err != nil { + return nil, fmt.Errorf("getting tenant id for subscription %s. Error: %w", p.env.GetSubscriptionId(), err) + } + + principalProfile, err := azureutil.LoggedInPrincipalProfile(ctx, p.userProfileService, tenantId) + if err != nil { + return nil, fmt.Errorf("fetching current user information: %w", err) + } + + return principalProfile, nil +} diff --git a/cli/azd/pkg/infra/provisioning/terraform/terraform_provider_test.go b/cli/azd/pkg/infra/provisioning/terraform/terraform_provider_test.go index 9d2c8f239d7..a46767c183f 100644 --- a/cli/azd/pkg/infra/provisioning/terraform/terraform_provider_test.go +++ b/cli/azd/pkg/infra/provisioning/terraform/terraform_provider_test.go @@ -12,6 +12,7 @@ import ( "testing" "github.com/azure/azure-dev/cli/azd/pkg/account" + "github.com/azure/azure-dev/cli/azd/pkg/azureutil" "github.com/azure/azure-dev/cli/azd/pkg/cloud" "github.com/azure/azure-dev/cli/azd/pkg/environment" "github.com/azure/azure-dev/cli/azd/pkg/exec" @@ -215,6 +216,12 @@ func prepareDestroyMocks(commandRunner *mockexec.MockCommandRunner) { type mockCurrentPrincipal struct{} +// CurrentPrincipalProfile implements provisioning.CurrentPrincipalIdProvider. +func (m *mockCurrentPrincipal) CurrentPrincipalProfile( + ctx context.Context) (*azureutil.LoggedInPrincipalProfileData, error) { + return &azureutil.LoggedInPrincipalProfileData{}, nil +} + func (m *mockCurrentPrincipal) CurrentPrincipalId(_ context.Context) (string, error) { return "11111111-1111-1111-1111-111111111111", nil } diff --git a/cli/azd/pkg/tools/azcli/user_profile.go b/cli/azd/pkg/tools/azcli/user_profile.go index a3cfae51e90..9ea5425cd7f 100644 --- a/cli/azd/pkg/tools/azcli/user_profile.go +++ b/cli/azd/pkg/tools/azcli/user_profile.go @@ -18,12 +18,14 @@ type UserProfileService struct { credentialProvider auth.MultiTenantCredentialProvider coreClientOptions *azcore.ClientOptions cloud *cloud.Cloud + authManager *auth.Manager } func NewUserProfileService( credentialProvider auth.MultiTenantCredentialProvider, clientOptionsBuilderFactory *azsdk.ClientOptionsBuilderFactory, cloud *cloud.Cloud, + authManager *auth.Manager, ) *UserProfileService { coreClientOptions := clientOptionsBuilderFactory.NewClientOptionsBuilder(). WithCloud(cloud.Configuration). @@ -34,6 +36,7 @@ func NewUserProfileService( credentialProvider: credentialProvider, coreClientOptions: coreClientOptions, cloud: cloud, + authManager: authManager, } } @@ -52,17 +55,46 @@ func (u *UserProfileService) createGraphClient(ctx context.Context, tenantId str } func (user *UserProfileService) GetSignedInUserId(ctx context.Context, tenantId string) (string, error) { - client, err := user.createGraphClient(ctx, tenantId) + userProfile, err := user.SignedProfile(ctx, tenantId) if err != nil { return "", err } + return userProfile.Id, nil +} + +func (user *UserProfileService) SignedProfile(ctx context.Context, tenantId string) (*graphsdk.UserProfile, error) { + client, err := user.createGraphClient(ctx, tenantId) + if err != nil { + return nil, err + } + userProfile, err := client.Me().Get(ctx) if err != nil { - return "", fmt.Errorf("failed retrieving current user profile: %w", err) + return nil, fmt.Errorf("failed retrieving current user profile: %w", err) } - return userProfile.Id, nil + return userProfile, nil +} + +func (user *UserProfileService) AppProfile( + ctx context.Context, tenantId string) (*graphsdk.Application, error) { + client, err := user.createGraphClient(ctx, tenantId) + if err != nil { + return nil, err + } + + appId, err := user.authManager.GetLoggedInServicePrincipalID(ctx) + if err != nil { + return nil, fmt.Errorf("getting logged in service principal ID: %w", err) + } + + appProfile, err := client.ApplicationById(*appId).GetByAppId(ctx) + if err != nil { + return nil, fmt.Errorf("failed retrieving current user profile: %w", err) + } + + return appProfile, nil } func (u *UserProfileService) GetAccessToken(ctx context.Context, tenantId string) (*AzCliAccessToken, error) { diff --git a/cli/azd/pkg/tools/azcli/user_profile_test.go b/cli/azd/pkg/tools/azcli/user_profile_test.go index 5c897ed346b..98c5688a92b 100644 --- a/cli/azd/pkg/tools/azcli/user_profile_test.go +++ b/cli/azd/pkg/tools/azcli/user_profile_test.go @@ -38,6 +38,7 @@ func Test_GetUserAccessToken(t *testing.T) { }, clientOptionsBuilderFactory, cloud.AzurePublic(), + nil, ) actual, err := userProfile.GetAccessToken(*mockContext.Context, "") @@ -68,6 +69,7 @@ func Test_GetSignedInUserId(t *testing.T) { &mocks.MockMultiTenantCredentialProvider{}, clientOptionsBuilderFactory, cloud.AzurePublic(), + nil, ) userId, err := userProfile.GetSignedInUserId(*mockContext.Context, "") @@ -88,6 +90,7 @@ func Test_GetSignedInUserId(t *testing.T) { &mocks.MockMultiTenantCredentialProvider{}, clientOptionsBuilderFactory, cloud.AzurePublic(), + nil, ) userId, err := userProfile.GetSignedInUserId(*mockContext.Context, "") diff --git a/cli/azd/pkg/tools/sqlcmd/sqlcmd_test.go b/cli/azd/pkg/tools/sqlcmd/sqlcmd_test.go index d81bfaf0425..c2cd0d95376 100644 --- a/cli/azd/pkg/tools/sqlcmd/sqlcmd_test.go +++ b/cli/azd/pkg/tools/sqlcmd/sqlcmd_test.go @@ -8,7 +8,6 @@ import ( "archive/zip" "bytes" "context" - "fmt" "io" "net/http" "os" @@ -43,7 +42,7 @@ func TestNewSqlCmdHubCli(t *testing.T) { return strings.Contains(args.Cmd, "sqlcmd") && len(args.Args) == 1 && args.Args[0] == "--version" }).Respond(exec.NewRunResult( 0, - fmt.Sprintf("%s", sqlCmdCliVersion.String()), + sqlCmdCliVersion.String(), "", )) From a7ba189688ba87ed2fd7c0223f7fb05716086aa9 Mon Sep 17 00:00:00 2001 From: Victor Vazquez Date: Mon, 22 Jul 2024 21:28:50 +0000 Subject: [PATCH 05/12] adds azd operation sqlScript to connect to a sql server with sqlcmd and run a sql script --- cli/azd/pkg/infra/provisioning/manager.go | 132 ++++++++++++++++-- .../pkg/infra/provisioning/manager_test.go | 35 +++-- 2 files changed, 149 insertions(+), 18 deletions(-) diff --git a/cli/azd/pkg/infra/provisioning/manager.go b/cli/azd/pkg/infra/provisioning/manager.go index 13325cff497..570721f0b43 100644 --- a/cli/azd/pkg/infra/provisioning/manager.go +++ b/cli/azd/pkg/infra/provisioning/manager.go @@ -22,6 +22,7 @@ import ( "github.com/azure/azure-dev/cli/azd/pkg/output" "github.com/azure/azure-dev/cli/azd/pkg/output/ux" "github.com/azure/azure-dev/cli/azd/pkg/prompt" + "github.com/azure/azure-dev/cli/azd/pkg/tools/sqlcmd" "gopkg.in/yaml.v3" ) @@ -40,6 +41,7 @@ type Manager struct { options *Options fileShareService storage.FileShareService cloud *cloud.Cloud + sqlCmdCli *sqlcmd.SqlCmdCli } // defaultOptions for this package. @@ -103,12 +105,16 @@ func (m *Manager) Deploy(ctx context.Context) (*DeployResult, error) { if !filepath.IsAbs(infraRoot) { infraRoot = filepath.Join(m.projectPath, m.options.Path) } - bindMountOperations, err := azdFileShareUploadOperations(infraRoot, *m.env) + model, err := azdOperations(infraRoot, *m.env) + if err != nil { + return nil, err + } + bindMountOperations, err := azdFileShareUploadOperations(model) azdOperationsEnabled := m.alphaFeatureManager.IsEnabled(AzdOperationsFeatureKey) if !azdOperationsEnabled && len(bindMountOperations) > 0 { m.console.Message(ctx, ErrBindMountOperationDisabled.Error()) } - if azdOperationsEnabled { + if azdOperationsEnabled && len(bindMountOperations) > 0 { if err != nil { return nil, fmt.Errorf("looking for azd fileShare upload operations: %w", err) } @@ -118,6 +124,20 @@ func (m *Manager) Deploy(ctx context.Context) (*DeployResult, error) { } } + sqlServerOperations, err := azdSqlServerOperations(model, filepath.Join(m.projectPath, m.options.Path)) + if !azdOperationsEnabled && len(sqlServerOperations) > 0 { + m.console.Message(ctx, ErrSqlScriptOperationDisabled.Error()) + } + if azdOperationsEnabled { + if err != nil { + return nil, fmt.Errorf("looking for azd sql scripts operations: %w", err) + } + if err := doSqlScriptOperation( + ctx, sqlServerOperations, m.console, *m.env, m.sqlCmdCli); err != nil { + return nil, err + } + } + // make sure any spinner is stopped m.console.StopSpinner(ctx, "", input.StepDone) @@ -126,6 +146,7 @@ func (m *Manager) Deploy(ctx context.Context) (*DeployResult, error) { const ( fileShareUploadOperation string = "FileShareUpload" + sqlServerOperation string = "SqlScript" azdOperationsFileName string = "azd.operations.yaml" ) @@ -142,6 +163,14 @@ type azdOperationFileShareUpload struct { Path string } +type azdOperationSqlServer struct { + Description string + Server string + Database string + Path string + Env map[string]string +} + type azdOperationsModel struct { Operations []azdOperation } @@ -175,12 +204,7 @@ func azdOperations(infraPath string, env environment.Environment) (azdOperations return operations, nil } -func azdFileShareUploadOperations(infraPath string, env environment.Environment) ([]azdOperationFileShareUpload, error) { - model, err := azdOperations(infraPath, env) - if err != nil { - return nil, err - } - +func azdFileShareUploadOperations(model azdOperationsModel) ([]azdOperationFileShareUpload, error) { var fileShareUploadOperations []azdOperationFileShareUpload for _, operation := range model.Operations { if operation.Type == fileShareUploadOperation { @@ -200,6 +224,29 @@ func azdFileShareUploadOperations(infraPath string, env environment.Environment) return fileShareUploadOperations, nil } +func azdSqlServerOperations(model azdOperationsModel, infraPath string) ([]azdOperationSqlServer, error) { + var sqlServerOperations []azdOperationSqlServer + for _, operation := range model.Operations { + if operation.Type == sqlServerOperation { + var sqlServerScript azdOperationSqlServer + bytes, err := json.Marshal(operation.Config) + if err != nil { + return nil, err + } + err = json.Unmarshal(bytes, &sqlServerScript) + if err != nil { + return nil, err + } + sqlServerScript.Description = operation.Description + if !filepath.IsAbs(sqlServerScript.Path) { + sqlServerScript.Path = filepath.Join(infraPath, sqlServerScript.Path) + } + sqlServerOperations = append(sqlServerOperations, sqlServerScript) + } + } + return sqlServerOperations, nil +} + var ErrAzdOperationsNotEnabled = fmt.Errorf(fmt.Sprintf( "azd operations (alpha feature) is required but disabled. You can enable azd operations by running: %s", output.WithGrayFormat(alpha.GetEnableCommand(AzdOperationsFeatureKey)))) @@ -211,6 +258,13 @@ var ErrBindMountOperationDisabled = fmt.Errorf( output.WithWarningFormat("Ignoring bind mounts."), ) +var ErrSqlScriptOperationDisabled = fmt.Errorf( + "%sYour project has sql server scripts.\n - %w\n%s\n", + output.WithWarningFormat("*Note: "), + ErrAzdOperationsNotEnabled, + output.WithWarningFormat("Ignoring scripts."), +) + func doBindMountOperation( ctx context.Context, fileShareUploadOperations []azdOperationFileShareUpload, @@ -251,6 +305,66 @@ func bindMountOperation( return fileShareService.UploadPath(ctx, subId, shareUrl, source) } +func doSqlScriptOperation( + ctx context.Context, + SqlScriptsOperations []azdOperationSqlServer, + console input.Console, + env environment.Environment, + sqlCmdCli *sqlcmd.SqlCmdCli, +) error { + if len(SqlScriptsOperations) > 0 { + console.ShowSpinner(ctx, "execute sql scripts", input.StepFailed) + } + for _, op := range SqlScriptsOperations { + filePath := op.Path + if op.Env != nil { + fileEnv := environment.NewWithValues("fileEnv", op.Env) + tmpDir, err := os.MkdirTemp("", "azd-sql-scripts") + if err != nil { + return err + } + defer os.RemoveAll(tmpDir) + data, err := os.ReadFile(filePath) + if err != nil { + return err + } + expString := osutil.NewExpandableString(string(data)) + evaluated, err := expString.Envsubst(fileEnv.Getenv) + if err != nil { + return err + } + filePath = filepath.Join(tmpDir, filepath.Base(filePath)) + err = os.WriteFile(filePath, []byte(evaluated), osutil.PermissionDirectory) + if err != nil { + return err + } + } + + if _, err := sqlCmdCli.ExecuteScript( + ctx, + op.Server, + op.Database, + filePath, + // sqlCmd cli uses DAC to connect to the server, but it doesn't know how to handle multi-tenant accounts. + // sqlCmd cli asks az or azd for a token w/o passing a tenant-id arg. + // sqlCmd cli runs from ~/.azd/bin: + // - azd doesn't know the tenant-id to use and defaults to get a token for home tenant. + // By setting the AZURE_SUBSCRIPTION_ID as env var to run sqlCmd cli, azd will use it to get tenant-id. + []string{ + fmt.Sprintf("%s=%s", environment.SubscriptionIdEnvVarName, env.GetSubscriptionId()), + }, + ); err != nil { + return fmt.Errorf("error run sqlcmd: %w", err) + } + console.MessageUxItem(ctx, &ux.DisplayedResource{ + Type: sqlServerOperation, + Name: op.Description, + State: ux.SucceededState, + }) + } + return nil +} + // Preview generates the list of changes to be applied as part of the provisioning. func (m *Manager) Preview(ctx context.Context) (*DeployPreviewResult, error) { // Apply the infrastructure deployment @@ -389,6 +503,7 @@ func NewManager( alphaFeatureManager *alpha.FeatureManager, fileShareService storage.FileShareService, cloud *cloud.Cloud, + sqlCmdCli *sqlcmd.SqlCmdCli, ) *Manager { return &Manager{ serviceLocator: serviceLocator, @@ -399,6 +514,7 @@ func NewManager( alphaFeatureManager: alphaFeatureManager, fileShareService: fileShareService, cloud: cloud, + sqlCmdCli: sqlCmdCli, } } diff --git a/cli/azd/pkg/infra/provisioning/manager_test.go b/cli/azd/pkg/infra/provisioning/manager_test.go index 3defbc37aa2..fff4315ed7e 100644 --- a/cli/azd/pkg/infra/provisioning/manager_test.go +++ b/cli/azd/pkg/infra/provisioning/manager_test.go @@ -17,6 +17,7 @@ import ( "github.com/azure/azure-dev/cli/azd/pkg/input" "github.com/azure/azure-dev/cli/azd/pkg/prompt" "github.com/azure/azure-dev/cli/azd/pkg/tools/azcli" + "github.com/azure/azure-dev/cli/azd/pkg/tools/sqlcmd" "github.com/azure/azure-dev/cli/azd/test/mocks" "github.com/azure/azure-dev/cli/azd/test/mocks/mockaccount" "github.com/azure/azure-dev/cli/azd/test/mocks/mockazcli" @@ -43,7 +44,8 @@ func TestProvisionInitializesEnvironment(t *testing.T) { }) registerContainerDependencies(mockContext, env) - + sqlcmd, err := sqlcmd.NewSqlCmdCli(*mockContext.Context, mockContext.Console, mockContext.CommandRunner) + require.NoError(t, err) envManager := &mockenv.MockEnvManager{} mgr := NewManager( mockContext.Container, @@ -54,8 +56,9 @@ func TestProvisionInitializesEnvironment(t *testing.T) { mockContext.AlphaFeaturesManager, nil, cloud.AzurePublic(), + sqlcmd, ) - err := mgr.Initialize(*mockContext.Context, "", Options{Provider: "test"}) + err = mgr.Initialize(*mockContext.Context, "", Options{Provider: "test"}) require.NoError(t, err) require.Equal(t, "00000000-0000-0000-0000-000000000000", env.GetSubscriptionId()) @@ -70,6 +73,8 @@ func TestManagerPreview(t *testing.T) { mockContext := mocks.NewMockContext(context.Background()) registerContainerDependencies(mockContext, env) + sqlcmd, err := sqlcmd.NewSqlCmdCli(*mockContext.Context, mockContext.Console, mockContext.CommandRunner) + require.NoError(t, err) envManager := &mockenv.MockEnvManager{} mgr := NewManager( @@ -81,8 +86,9 @@ func TestManagerPreview(t *testing.T) { mockContext.AlphaFeaturesManager, nil, cloud.AzurePublic(), + sqlcmd, ) - err := mgr.Initialize(*mockContext.Context, "", Options{Provider: "test"}) + err = mgr.Initialize(*mockContext.Context, "", Options{Provider: "test"}) require.NoError(t, err) deploymentPlan, err := mgr.Preview(*mockContext.Context) @@ -99,7 +105,8 @@ func TestManagerGetState(t *testing.T) { mockContext := mocks.NewMockContext(context.Background()) registerContainerDependencies(mockContext, env) - + sqlcmd, err := sqlcmd.NewSqlCmdCli(*mockContext.Context, mockContext.Console, mockContext.CommandRunner) + require.NoError(t, err) envManager := &mockenv.MockEnvManager{} mgr := NewManager( mockContext.Container, @@ -110,8 +117,9 @@ func TestManagerGetState(t *testing.T) { mockContext.AlphaFeaturesManager, nil, cloud.AzurePublic(), + sqlcmd, ) - err := mgr.Initialize(*mockContext.Context, "", Options{Provider: "test"}) + err = mgr.Initialize(*mockContext.Context, "", Options{Provider: "test"}) require.NoError(t, err) getResult, err := mgr.State(*mockContext.Context, nil) @@ -128,7 +136,8 @@ func TestManagerDeploy(t *testing.T) { mockContext := mocks.NewMockContext(context.Background()) registerContainerDependencies(mockContext, env) - + sqlcmd, err := sqlcmd.NewSqlCmdCli(*mockContext.Context, mockContext.Console, mockContext.CommandRunner) + require.NoError(t, err) envManager := &mockenv.MockEnvManager{} mgr := NewManager( mockContext.Container, @@ -139,8 +148,9 @@ func TestManagerDeploy(t *testing.T) { mockContext.AlphaFeaturesManager, nil, cloud.AzurePublic(), + sqlcmd, ) - err := mgr.Initialize(*mockContext.Context, "", Options{Provider: "test"}) + err = mgr.Initialize(*mockContext.Context, "", Options{Provider: "test"}) require.NoError(t, err) deployResult, err := mgr.Deploy(*mockContext.Context) @@ -164,6 +174,8 @@ func TestManagerDestroyWithPositiveConfirmation(t *testing.T) { envManager := &mockenv.MockEnvManager{} envManager.On("Save", *mockContext.Context, env).Return(nil) + sqlcmd, err := sqlcmd.NewSqlCmdCli(*mockContext.Context, mockContext.Console, mockContext.CommandRunner) + require.NoError(t, err) mgr := NewManager( mockContext.Container, @@ -174,8 +186,9 @@ func TestManagerDestroyWithPositiveConfirmation(t *testing.T) { mockContext.AlphaFeaturesManager, nil, cloud.AzurePublic(), + sqlcmd, ) - err := mgr.Initialize(*mockContext.Context, "", Options{Provider: "test"}) + err = mgr.Initialize(*mockContext.Context, "", Options{Provider: "test"}) require.NoError(t, err) destroyOptions := NewDestroyOptions(false, false) @@ -199,7 +212,8 @@ func TestManagerDestroyWithNegativeConfirmation(t *testing.T) { }).Respond(false) registerContainerDependencies(mockContext, env) - + sqlcmd, err := sqlcmd.NewSqlCmdCli(*mockContext.Context, mockContext.Console, mockContext.CommandRunner) + require.NoError(t, err) envManager := &mockenv.MockEnvManager{} mgr := NewManager( mockContext.Container, @@ -210,8 +224,9 @@ func TestManagerDestroyWithNegativeConfirmation(t *testing.T) { mockContext.AlphaFeaturesManager, nil, cloud.AzurePublic(), + sqlcmd, ) - err := mgr.Initialize(*mockContext.Context, "", Options{Provider: "test"}) + err = mgr.Initialize(*mockContext.Context, "", Options{Provider: "test"}) require.NoError(t, err) destroyOptions := NewDestroyOptions(false, false) From 2af1753907a34bfe372b2e8ab4a56d65f0adb681 Mon Sep 17 00:00:00 2001 From: Victor Vazquez Date: Tue, 23 Jul 2024 01:05:52 +0000 Subject: [PATCH 06/12] mocks --- .../pkg/infra/provisioning/manager_test.go | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/cli/azd/pkg/infra/provisioning/manager_test.go b/cli/azd/pkg/infra/provisioning/manager_test.go index fff4315ed7e..1dd8262ce1b 100644 --- a/cli/azd/pkg/infra/provisioning/manager_test.go +++ b/cli/azd/pkg/infra/provisioning/manager_test.go @@ -11,6 +11,7 @@ import ( "github.com/azure/azure-dev/cli/azd/pkg/account" "github.com/azure/azure-dev/cli/azd/pkg/cloud" "github.com/azure/azure-dev/cli/azd/pkg/environment" + "github.com/azure/azure-dev/cli/azd/pkg/exec" "github.com/azure/azure-dev/cli/azd/pkg/infra/provisioning" . "github.com/azure/azure-dev/cli/azd/pkg/infra/provisioning" "github.com/azure/azure-dev/cli/azd/pkg/infra/provisioning/test" @@ -42,6 +43,9 @@ func TestProvisionInitializesEnvironment(t *testing.T) { // Select the first from the list return 0, nil }) + mockContext.CommandRunner.When(func(args exec.RunArgs, command string) bool { + return strings.Contains(command, "sqlcmd --version") + }).Respond(exec.NewRunResult(0, "1.8.0", "")) registerContainerDependencies(mockContext, env) sqlcmd, err := sqlcmd.NewSqlCmdCli(*mockContext.Context, mockContext.Console, mockContext.CommandRunner) @@ -73,6 +77,10 @@ func TestManagerPreview(t *testing.T) { mockContext := mocks.NewMockContext(context.Background()) registerContainerDependencies(mockContext, env) + mockContext.CommandRunner.When(func(args exec.RunArgs, command string) bool { + return strings.Contains(command, "sqlcmd --version") + }).Respond(exec.NewRunResult(0, "1.8.0", "")) + sqlcmd, err := sqlcmd.NewSqlCmdCli(*mockContext.Context, mockContext.Console, mockContext.CommandRunner) require.NoError(t, err) @@ -105,6 +113,10 @@ func TestManagerGetState(t *testing.T) { mockContext := mocks.NewMockContext(context.Background()) registerContainerDependencies(mockContext, env) + mockContext.CommandRunner.When(func(args exec.RunArgs, command string) bool { + return strings.Contains(command, "sqlcmd --version") + }).Respond(exec.NewRunResult(0, "1.8.0", "")) + sqlcmd, err := sqlcmd.NewSqlCmdCli(*mockContext.Context, mockContext.Console, mockContext.CommandRunner) require.NoError(t, err) envManager := &mockenv.MockEnvManager{} @@ -136,6 +148,10 @@ func TestManagerDeploy(t *testing.T) { mockContext := mocks.NewMockContext(context.Background()) registerContainerDependencies(mockContext, env) + mockContext.CommandRunner.When(func(args exec.RunArgs, command string) bool { + return strings.Contains(command, "sqlcmd --version") + }).Respond(exec.NewRunResult(0, "1.8.0", "")) + sqlcmd, err := sqlcmd.NewSqlCmdCli(*mockContext.Context, mockContext.Console, mockContext.CommandRunner) require.NoError(t, err) envManager := &mockenv.MockEnvManager{} @@ -171,6 +187,9 @@ func TestManagerDestroyWithPositiveConfirmation(t *testing.T) { }).Respond(true) registerContainerDependencies(mockContext, env) + mockContext.CommandRunner.When(func(args exec.RunArgs, command string) bool { + return strings.Contains(command, "sqlcmd --version") + }).Respond(exec.NewRunResult(0, "1.8.0", "")) envManager := &mockenv.MockEnvManager{} envManager.On("Save", *mockContext.Context, env).Return(nil) @@ -212,6 +231,10 @@ func TestManagerDestroyWithNegativeConfirmation(t *testing.T) { }).Respond(false) registerContainerDependencies(mockContext, env) + mockContext.CommandRunner.When(func(args exec.RunArgs, command string) bool { + return strings.Contains(command, "sqlcmd --version") + }).Respond(exec.NewRunResult(0, "1.8.0", "")) + sqlcmd, err := sqlcmd.NewSqlCmdCli(*mockContext.Context, mockContext.Console, mockContext.CommandRunner) require.NoError(t, err) envManager := &mockenv.MockEnvManager{} From 2890f5d2bd9eae1eade94b7bddf30b530c88723c Mon Sep 17 00:00:00 2001 From: Victor Vazquez Date: Tue, 23 Jul 2024 02:22:24 +0000 Subject: [PATCH 07/12] cross os --- cli/azd/pkg/infra/provisioning/manager_test.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cli/azd/pkg/infra/provisioning/manager_test.go b/cli/azd/pkg/infra/provisioning/manager_test.go index 1dd8262ce1b..62b9b11b489 100644 --- a/cli/azd/pkg/infra/provisioning/manager_test.go +++ b/cli/azd/pkg/infra/provisioning/manager_test.go @@ -44,7 +44,7 @@ func TestProvisionInitializesEnvironment(t *testing.T) { return 0, nil }) mockContext.CommandRunner.When(func(args exec.RunArgs, command string) bool { - return strings.Contains(command, "sqlcmd --version") + return strings.Contains(args.Cmd, "sqlcmd") && len(args.Args) == 1 && args.Args[0] == "--version" }).Respond(exec.NewRunResult(0, "1.8.0", "")) registerContainerDependencies(mockContext, env) @@ -78,7 +78,7 @@ func TestManagerPreview(t *testing.T) { mockContext := mocks.NewMockContext(context.Background()) registerContainerDependencies(mockContext, env) mockContext.CommandRunner.When(func(args exec.RunArgs, command string) bool { - return strings.Contains(command, "sqlcmd --version") + return strings.Contains(args.Cmd, "sqlcmd") && len(args.Args) == 1 && args.Args[0] == "--version" }).Respond(exec.NewRunResult(0, "1.8.0", "")) sqlcmd, err := sqlcmd.NewSqlCmdCli(*mockContext.Context, mockContext.Console, mockContext.CommandRunner) @@ -114,7 +114,7 @@ func TestManagerGetState(t *testing.T) { mockContext := mocks.NewMockContext(context.Background()) registerContainerDependencies(mockContext, env) mockContext.CommandRunner.When(func(args exec.RunArgs, command string) bool { - return strings.Contains(command, "sqlcmd --version") + return strings.Contains(args.Cmd, "sqlcmd") && len(args.Args) == 1 && args.Args[0] == "--version" }).Respond(exec.NewRunResult(0, "1.8.0", "")) sqlcmd, err := sqlcmd.NewSqlCmdCli(*mockContext.Context, mockContext.Console, mockContext.CommandRunner) @@ -149,7 +149,7 @@ func TestManagerDeploy(t *testing.T) { mockContext := mocks.NewMockContext(context.Background()) registerContainerDependencies(mockContext, env) mockContext.CommandRunner.When(func(args exec.RunArgs, command string) bool { - return strings.Contains(command, "sqlcmd --version") + return strings.Contains(args.Cmd, "sqlcmd") && len(args.Args) == 1 && args.Args[0] == "--version" }).Respond(exec.NewRunResult(0, "1.8.0", "")) sqlcmd, err := sqlcmd.NewSqlCmdCli(*mockContext.Context, mockContext.Console, mockContext.CommandRunner) @@ -188,7 +188,7 @@ func TestManagerDestroyWithPositiveConfirmation(t *testing.T) { registerContainerDependencies(mockContext, env) mockContext.CommandRunner.When(func(args exec.RunArgs, command string) bool { - return strings.Contains(command, "sqlcmd --version") + return strings.Contains(args.Cmd, "sqlcmd") && len(args.Args) == 1 && args.Args[0] == "--version" }).Respond(exec.NewRunResult(0, "1.8.0", "")) envManager := &mockenv.MockEnvManager{} @@ -232,7 +232,7 @@ func TestManagerDestroyWithNegativeConfirmation(t *testing.T) { registerContainerDependencies(mockContext, env) mockContext.CommandRunner.When(func(args exec.RunArgs, command string) bool { - return strings.Contains(command, "sqlcmd --version") + return strings.Contains(args.Cmd, "sqlcmd") && len(args.Args) == 1 && args.Args[0] == "--version" }).Respond(exec.NewRunResult(0, "1.8.0", "")) sqlcmd, err := sqlcmd.NewSqlCmdCli(*mockContext.Context, mockContext.Console, mockContext.CommandRunner) From 44c2f5632aa7dc79e21615280485f2e67b23165d Mon Sep 17 00:00:00 2001 From: Victor Vazquez Date: Wed, 24 Jul 2024 23:16:27 +0000 Subject: [PATCH 08/12] use project lifecycle events for operations. Register to postprovision --- cli/azd/internal/cmd/provision.go | 15 + cli/azd/pkg/apphost/generate.go | 4 +- cli/azd/pkg/apphost/generate_test.go | 4 +- cli/azd/pkg/infra/provisioning/manager.go | 310 +++--------------- .../provisioning/operations/azd_operation.go | 68 ++++ .../operations/file_share_upload.go | 90 +++++ .../provisioning/operations/sql_script.go | 117 +++++++ cli/azd/pkg/project/dotnet_importer.go | 9 +- 8 files changed, 342 insertions(+), 275 deletions(-) create mode 100644 cli/azd/pkg/infra/provisioning/operations/azd_operation.go create mode 100644 cli/azd/pkg/infra/provisioning/operations/file_share_upload.go create mode 100644 cli/azd/pkg/infra/provisioning/operations/sql_script.go diff --git a/cli/azd/internal/cmd/provision.go b/cli/azd/internal/cmd/provision.go index f7d1fda84b1..1c186d68580 100644 --- a/cli/azd/internal/cmd/provision.go +++ b/cli/azd/internal/cmd/provision.go @@ -191,6 +191,21 @@ func (p *ProvisionAction) Run(ctx context.Context) (*actions.ActionResult, error return nil, fmt.Errorf("initializing provisioning manager: %w", err) } + // register operations + operations, err := p.provisionManager.Operations(ctx) + if err != nil { + return nil, fmt.Errorf("registering operations: %w", err) + } + for _, operation := range operations { + err := p.projectConfig.AddHandler( + "postprovision", func(ctx context.Context, args project.ProjectLifecycleEventArgs) error { + return operation(ctx) + }) + if err != nil { + return nil, fmt.Errorf("registering operation: %w", err) + } + } + // Get Subscription to Display in Command Title Note // Subscription and Location are ONLY displayed when they are available (found from env), otherwise, this message // is not displayed. diff --git a/cli/azd/pkg/apphost/generate.go b/cli/azd/pkg/apphost/generate.go index efa94837288..3f5135f8d84 100644 --- a/cli/azd/pkg/apphost/generate.go +++ b/cli/azd/pkg/apphost/generate.go @@ -23,7 +23,7 @@ import ( "github.com/azure/azure-dev/cli/azd/pkg/convert" "github.com/azure/azure-dev/cli/azd/pkg/custommaps" "github.com/azure/azure-dev/cli/azd/pkg/environment" - "github.com/azure/azure-dev/cli/azd/pkg/infra/provisioning" + "github.com/azure/azure-dev/cli/azd/pkg/infra/provisioning/operations" "github.com/azure/azure-dev/cli/azd/pkg/osutil" "github.com/azure/azure-dev/cli/azd/pkg/output" "github.com/azure/azure-dev/cli/azd/resources" @@ -318,7 +318,7 @@ func BicepTemplate(name string, manifest *Manifest, options AppHostOptions) (*me } } else { // returning fs because this error can be handled by the caller as expected - return fs, provisioning.ErrBindMountOperationDisabled + return fs, operations.ErrBindMountOperationDisabled } } diff --git a/cli/azd/pkg/apphost/generate_test.go b/cli/azd/pkg/apphost/generate_test.go index 4533b451159..200e9707615 100644 --- a/cli/azd/pkg/apphost/generate_test.go +++ b/cli/azd/pkg/apphost/generate_test.go @@ -12,7 +12,7 @@ import ( "testing" "github.com/azure/azure-dev/cli/azd/pkg/exec" - "github.com/azure/azure-dev/cli/azd/pkg/infra/provisioning" + "github.com/azure/azure-dev/cli/azd/pkg/infra/provisioning/operations" "github.com/azure/azure-dev/cli/azd/pkg/osutil" "github.com/azure/azure-dev/cli/azd/pkg/tools/dotnet" "github.com/azure/azure-dev/cli/azd/test/mocks" @@ -228,7 +228,7 @@ func TestAspireContainerGeneration(t *testing.T) { } _, err = BicepTemplate("main", m, AppHostOptions{}) - require.Error(t, err, provisioning.ErrBindMountOperationDisabled) + require.Error(t, err, operations.ErrBindMountOperationDisabled) files, err := BicepTemplate("main", m, AppHostOptions{ AzdOperations: true, diff --git a/cli/azd/pkg/infra/provisioning/manager.go b/cli/azd/pkg/infra/provisioning/manager.go index 570721f0b43..e128676a8b6 100644 --- a/cli/azd/pkg/infra/provisioning/manager.go +++ b/cli/azd/pkg/infra/provisioning/manager.go @@ -6,9 +6,7 @@ package provisioning import ( "context" "encoding/json" - "errors" "fmt" - "os" "path/filepath" "github.com/azure/azure-dev/cli/azd/pkg/alpha" @@ -16,14 +14,11 @@ import ( "github.com/azure/azure-dev/cli/azd/pkg/cloud" "github.com/azure/azure-dev/cli/azd/pkg/environment" "github.com/azure/azure-dev/cli/azd/pkg/infra" + "github.com/azure/azure-dev/cli/azd/pkg/infra/provisioning/operations" "github.com/azure/azure-dev/cli/azd/pkg/input" "github.com/azure/azure-dev/cli/azd/pkg/ioc" - "github.com/azure/azure-dev/cli/azd/pkg/osutil" - "github.com/azure/azure-dev/cli/azd/pkg/output" - "github.com/azure/azure-dev/cli/azd/pkg/output/ux" "github.com/azure/azure-dev/cli/azd/pkg/prompt" "github.com/azure/azure-dev/cli/azd/pkg/tools/sqlcmd" - "gopkg.in/yaml.v3" ) type DefaultProviderResolver func() (ProviderKind, error) @@ -50,6 +45,48 @@ const ( defaultPath = "infra" ) +func (m *Manager) Operations(ctx context.Context) ([]func(ctx context.Context) error, error) { + //Get a list of operations + result := []func(ctx context.Context) error{} + infraRoot := m.options.Path + if !filepath.IsAbs(infraRoot) { + infraRoot = filepath.Join(m.projectPath, m.options.Path) + } + model, err := operations.AzdOperations(infraRoot, *m.env) + if err != nil { + return result, err + } + bindMountOperations, err := operations.FileShareUploads(model) + azdOperationsEnabled := m.alphaFeatureManager.IsEnabled(operations.AzdOperationsFeatureKey) + if !azdOperationsEnabled && len(bindMountOperations) > 0 { + m.console.Message(ctx, operations.ErrBindMountOperationDisabled.Error()) + } + if azdOperationsEnabled && len(bindMountOperations) > 0 { + if err != nil { + return result, fmt.Errorf("looking for azd fileShare upload operations: %w", err) + } + result = append(result, func(context context.Context) error { + return operations.DoBindMount( + context, bindMountOperations, m.env, m.console, m.fileShareService, m.cloud.StorageEndpointSuffix) + }) + } + + sqlServerOperations, err := operations.SqlScripts(model, filepath.Join(m.projectPath, m.options.Path)) + if !azdOperationsEnabled && len(sqlServerOperations) > 0 { + m.console.Message(ctx, operations.ErrSqlScriptOperationDisabled.Error()) + } + if azdOperationsEnabled { + if err != nil { + return result, fmt.Errorf("looking for azd sql scripts operations: %w", err) + } + result = append(result, func(context context.Context) error { + return operations.DoSqlScript( + context, sqlServerOperations, m.console, *m.env, m.sqlCmdCli) + }) + } + return result, nil +} + func (m *Manager) Initialize(ctx context.Context, projectPath string, options Options) error { // applied defaults if missing if options.Module == "" { @@ -61,7 +98,6 @@ func (m *Manager) Initialize(ctx context.Context, projectPath string, options Op m.projectPath = projectPath m.options = &options - provider, err := m.newProvider(ctx) if err != nil { return fmt.Errorf("initializing infrastructure provider: %w", err) @@ -81,8 +117,6 @@ func (m *Manager) State(ctx context.Context, options *StateOptions) (*StateResul return result, nil } -var AzdOperationsFeatureKey = alpha.MustFeatureKey("azd.operations") - // Deploys the Azure infrastructure for the specified project func (m *Manager) Deploy(ctx context.Context) (*DeployResult, error) { // Apply the infrastructure deployment @@ -101,270 +135,12 @@ func (m *Manager) Deploy(ctx context.Context) (*DeployResult, error) { return nil, fmt.Errorf("updating environment with deployment outputs: %w", err) } - infraRoot := m.options.Path - if !filepath.IsAbs(infraRoot) { - infraRoot = filepath.Join(m.projectPath, m.options.Path) - } - model, err := azdOperations(infraRoot, *m.env) - if err != nil { - return nil, err - } - bindMountOperations, err := azdFileShareUploadOperations(model) - azdOperationsEnabled := m.alphaFeatureManager.IsEnabled(AzdOperationsFeatureKey) - if !azdOperationsEnabled && len(bindMountOperations) > 0 { - m.console.Message(ctx, ErrBindMountOperationDisabled.Error()) - } - if azdOperationsEnabled && len(bindMountOperations) > 0 { - if err != nil { - return nil, fmt.Errorf("looking for azd fileShare upload operations: %w", err) - } - if err := doBindMountOperation( - ctx, bindMountOperations, *m.env, m.console, m.fileShareService, m.cloud.StorageEndpointSuffix); err != nil { - return nil, fmt.Errorf("error running bind mount operation: %w", err) - } - } - - sqlServerOperations, err := azdSqlServerOperations(model, filepath.Join(m.projectPath, m.options.Path)) - if !azdOperationsEnabled && len(sqlServerOperations) > 0 { - m.console.Message(ctx, ErrSqlScriptOperationDisabled.Error()) - } - if azdOperationsEnabled { - if err != nil { - return nil, fmt.Errorf("looking for azd sql scripts operations: %w", err) - } - if err := doSqlScriptOperation( - ctx, sqlServerOperations, m.console, *m.env, m.sqlCmdCli); err != nil { - return nil, err - } - } - // make sure any spinner is stopped m.console.StopSpinner(ctx, "", input.StepDone) return deployResult, nil } -const ( - fileShareUploadOperation string = "FileShareUpload" - sqlServerOperation string = "SqlScript" - azdOperationsFileName string = "azd.operations.yaml" -) - -type azdOperation struct { - Type string - Description string - Config any -} - -type azdOperationFileShareUpload struct { - Description string - StorageAccount string - FileShareName string - Path string -} - -type azdOperationSqlServer struct { - Description string - Server string - Database string - Path string - Env map[string]string -} - -type azdOperationsModel struct { - Operations []azdOperation -} - -func azdOperations(infraPath string, env environment.Environment) (azdOperationsModel, error) { - path := filepath.Join(infraPath, azdOperationsFileName) - data, err := os.ReadFile(path) - if err != nil { - if errors.Is(err, os.ErrNotExist) { - // file not found is not an error, there's just nothing to do - return azdOperationsModel{}, nil - } - return azdOperationsModel{}, err - } - - // resolve environment variables - expString := osutil.NewExpandableString(string(data)) - evaluated, err := expString.Envsubst(env.Getenv) - if err != nil { - return azdOperationsModel{}, err - } - data = []byte(evaluated) - - // Unmarshal the file into azdOperationsModel - var operations azdOperationsModel - err = yaml.Unmarshal(data, &operations) - if err != nil { - return azdOperationsModel{}, err - } - - return operations, nil -} - -func azdFileShareUploadOperations(model azdOperationsModel) ([]azdOperationFileShareUpload, error) { - var fileShareUploadOperations []azdOperationFileShareUpload - for _, operation := range model.Operations { - if operation.Type == fileShareUploadOperation { - var fileShareUpload azdOperationFileShareUpload - bytes, err := json.Marshal(operation.Config) - if err != nil { - return nil, err - } - err = json.Unmarshal(bytes, &fileShareUpload) - if err != nil { - return nil, err - } - fileShareUpload.Description = operation.Description - fileShareUploadOperations = append(fileShareUploadOperations, fileShareUpload) - } - } - return fileShareUploadOperations, nil -} - -func azdSqlServerOperations(model azdOperationsModel, infraPath string) ([]azdOperationSqlServer, error) { - var sqlServerOperations []azdOperationSqlServer - for _, operation := range model.Operations { - if operation.Type == sqlServerOperation { - var sqlServerScript azdOperationSqlServer - bytes, err := json.Marshal(operation.Config) - if err != nil { - return nil, err - } - err = json.Unmarshal(bytes, &sqlServerScript) - if err != nil { - return nil, err - } - sqlServerScript.Description = operation.Description - if !filepath.IsAbs(sqlServerScript.Path) { - sqlServerScript.Path = filepath.Join(infraPath, sqlServerScript.Path) - } - sqlServerOperations = append(sqlServerOperations, sqlServerScript) - } - } - return sqlServerOperations, nil -} - -var ErrAzdOperationsNotEnabled = fmt.Errorf(fmt.Sprintf( - "azd operations (alpha feature) is required but disabled. You can enable azd operations by running: %s", - output.WithGrayFormat(alpha.GetEnableCommand(AzdOperationsFeatureKey)))) - -var ErrBindMountOperationDisabled = fmt.Errorf( - "%sYour project has bind mounts.\n - %w\n%s\n", - output.WithWarningFormat("*Note: "), - ErrAzdOperationsNotEnabled, - output.WithWarningFormat("Ignoring bind mounts."), -) - -var ErrSqlScriptOperationDisabled = fmt.Errorf( - "%sYour project has sql server scripts.\n - %w\n%s\n", - output.WithWarningFormat("*Note: "), - ErrAzdOperationsNotEnabled, - output.WithWarningFormat("Ignoring scripts."), -) - -func doBindMountOperation( - ctx context.Context, - fileShareUploadOperations []azdOperationFileShareUpload, - env environment.Environment, - console input.Console, - fileShareService storage.FileShareService, - cloudStorageEndpointSuffix string, -) error { - if len(fileShareUploadOperations) > 0 { - console.ShowSpinner(ctx, "uploading files to fileShare", input.StepFailed) - } - for _, op := range fileShareUploadOperations { - if err := bindMountOperation( - ctx, - fileShareService, - cloudStorageEndpointSuffix, - env.GetSubscriptionId(), - op.StorageAccount, - op.FileShareName, - op.Path); err != nil { - return fmt.Errorf("error binding mount: %w", err) - } - console.MessageUxItem(ctx, &ux.DisplayedResource{ - Type: fileShareUploadOperation, - Name: op.Description, - State: ux.SucceededState, - }) - } - return nil -} - -func bindMountOperation( - ctx context.Context, - fileShareService storage.FileShareService, - cloud, subId, storageAccount, fileShareName, source string) error { - - shareUrl := fmt.Sprintf("https://%s.file.%s/%s", storageAccount, cloud, fileShareName) - return fileShareService.UploadPath(ctx, subId, shareUrl, source) -} - -func doSqlScriptOperation( - ctx context.Context, - SqlScriptsOperations []azdOperationSqlServer, - console input.Console, - env environment.Environment, - sqlCmdCli *sqlcmd.SqlCmdCli, -) error { - if len(SqlScriptsOperations) > 0 { - console.ShowSpinner(ctx, "execute sql scripts", input.StepFailed) - } - for _, op := range SqlScriptsOperations { - filePath := op.Path - if op.Env != nil { - fileEnv := environment.NewWithValues("fileEnv", op.Env) - tmpDir, err := os.MkdirTemp("", "azd-sql-scripts") - if err != nil { - return err - } - defer os.RemoveAll(tmpDir) - data, err := os.ReadFile(filePath) - if err != nil { - return err - } - expString := osutil.NewExpandableString(string(data)) - evaluated, err := expString.Envsubst(fileEnv.Getenv) - if err != nil { - return err - } - filePath = filepath.Join(tmpDir, filepath.Base(filePath)) - err = os.WriteFile(filePath, []byte(evaluated), osutil.PermissionDirectory) - if err != nil { - return err - } - } - - if _, err := sqlCmdCli.ExecuteScript( - ctx, - op.Server, - op.Database, - filePath, - // sqlCmd cli uses DAC to connect to the server, but it doesn't know how to handle multi-tenant accounts. - // sqlCmd cli asks az or azd for a token w/o passing a tenant-id arg. - // sqlCmd cli runs from ~/.azd/bin: - // - azd doesn't know the tenant-id to use and defaults to get a token for home tenant. - // By setting the AZURE_SUBSCRIPTION_ID as env var to run sqlCmd cli, azd will use it to get tenant-id. - []string{ - fmt.Sprintf("%s=%s", environment.SubscriptionIdEnvVarName, env.GetSubscriptionId()), - }, - ); err != nil { - return fmt.Errorf("error run sqlcmd: %w", err) - } - console.MessageUxItem(ctx, &ux.DisplayedResource{ - Type: sqlServerOperation, - Name: op.Description, - State: ux.SucceededState, - }) - } - return nil -} - // Preview generates the list of changes to be applied as part of the provisioning. func (m *Manager) Preview(ctx context.Context) (*DeployPreviewResult, error) { // Apply the infrastructure deployment diff --git a/cli/azd/pkg/infra/provisioning/operations/azd_operation.go b/cli/azd/pkg/infra/provisioning/operations/azd_operation.go new file mode 100644 index 00000000000..079d0f4be57 --- /dev/null +++ b/cli/azd/pkg/infra/provisioning/operations/azd_operation.go @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package operations + +import ( + "errors" + "fmt" + "os" + "path/filepath" + + "github.com/azure/azure-dev/cli/azd/pkg/alpha" + "github.com/azure/azure-dev/cli/azd/pkg/environment" + "github.com/azure/azure-dev/cli/azd/pkg/osutil" + "github.com/azure/azure-dev/cli/azd/pkg/output" + "gopkg.in/yaml.v3" +) + +type azdOperation struct { + Type string + Description string + Config any +} + +type AzdOperationsModel struct { + Operations []azdOperation +} + +const ( + fileShareUploadOperation string = "FileShareUpload" + sqlServerOperation string = "SqlScript" + azdOperationsFileName string = "azd.operations.yaml" +) + +var AzdOperationsFeatureKey = alpha.MustFeatureKey("azd.operations") + +var ErrAzdOperationsNotEnabled = fmt.Errorf(fmt.Sprintf( + "azd operations (alpha feature) is required but disabled. You can enable azd operations by running: %s", + output.WithGrayFormat(alpha.GetEnableCommand(AzdOperationsFeatureKey)))) + +func AzdOperations(infraPath string, env environment.Environment) (AzdOperationsModel, error) { + path := filepath.Join(infraPath, azdOperationsFileName) + data, err := os.ReadFile(path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + // file not found is not an error, there's just nothing to do + return AzdOperationsModel{}, nil + } + return AzdOperationsModel{}, err + } + + // resolve environment variables + expString := osutil.NewExpandableString(string(data)) + evaluated, err := expString.Envsubst(env.Getenv) + if err != nil { + return AzdOperationsModel{}, err + } + data = []byte(evaluated) + + // Unmarshal the file into azdOperationsModel + var operations AzdOperationsModel + err = yaml.Unmarshal(data, &operations) + if err != nil { + return AzdOperationsModel{}, err + } + + return operations, nil +} diff --git a/cli/azd/pkg/infra/provisioning/operations/file_share_upload.go b/cli/azd/pkg/infra/provisioning/operations/file_share_upload.go new file mode 100644 index 00000000000..82e24958405 --- /dev/null +++ b/cli/azd/pkg/infra/provisioning/operations/file_share_upload.go @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package operations + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/azure/azure-dev/cli/azd/pkg/azsdk/storage" + "github.com/azure/azure-dev/cli/azd/pkg/environment" + "github.com/azure/azure-dev/cli/azd/pkg/input" + "github.com/azure/azure-dev/cli/azd/pkg/output" + "github.com/azure/azure-dev/cli/azd/pkg/output/ux" +) + +type FileShareUpload struct { + Description string + StorageAccount string + FileShareName string + Path string +} + +var ErrBindMountOperationDisabled = fmt.Errorf( + "%sYour project has bind mounts.\n - %w\n%s\n", + output.WithWarningFormat("*Note: "), + ErrAzdOperationsNotEnabled, + output.WithWarningFormat("Ignoring bind mounts."), +) + +func FileShareUploads(model AzdOperationsModel) ([]FileShareUpload, error) { + var fileShareUploadOperations []FileShareUpload + for _, operation := range model.Operations { + if operation.Type == fileShareUploadOperation { + var fileShareUpload FileShareUpload + bytes, err := json.Marshal(operation.Config) + if err != nil { + return nil, err + } + err = json.Unmarshal(bytes, &fileShareUpload) + if err != nil { + return nil, err + } + fileShareUpload.Description = operation.Description + fileShareUploadOperations = append(fileShareUploadOperations, fileShareUpload) + } + } + return fileShareUploadOperations, nil +} + +func DoBindMount( + ctx context.Context, + fileShareUploadOperations []FileShareUpload, + env *environment.Environment, + console input.Console, + fileShareService storage.FileShareService, + cloudStorageEndpointSuffix string, +) error { + if len(fileShareUploadOperations) > 0 { + console.ShowSpinner(ctx, "uploading files to fileShare", input.StepFailed) + } + for _, op := range fileShareUploadOperations { + if err := bindMountOperation( + ctx, + fileShareService, + cloudStorageEndpointSuffix, + env.GetSubscriptionId(), + op.StorageAccount, + op.FileShareName, + op.Path); err != nil { + return fmt.Errorf("error binding mount: %w", err) + } + console.MessageUxItem(ctx, &ux.DisplayedResource{ + Type: fileShareUploadOperation, + Name: op.Description, + State: ux.SucceededState, + }) + } + return nil +} + +func bindMountOperation( + ctx context.Context, + fileShareService storage.FileShareService, + cloud, subId, storageAccount, fileShareName, source string) error { + + shareUrl := fmt.Sprintf("https://%s.file.%s/%s", storageAccount, cloud, fileShareName) + return fileShareService.UploadPath(ctx, subId, shareUrl, source) +} diff --git a/cli/azd/pkg/infra/provisioning/operations/sql_script.go b/cli/azd/pkg/infra/provisioning/operations/sql_script.go new file mode 100644 index 00000000000..e053a1d5d1c --- /dev/null +++ b/cli/azd/pkg/infra/provisioning/operations/sql_script.go @@ -0,0 +1,117 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package operations + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + + "github.com/azure/azure-dev/cli/azd/pkg/environment" + "github.com/azure/azure-dev/cli/azd/pkg/input" + "github.com/azure/azure-dev/cli/azd/pkg/osutil" + "github.com/azure/azure-dev/cli/azd/pkg/output" + "github.com/azure/azure-dev/cli/azd/pkg/output/ux" + "github.com/azure/azure-dev/cli/azd/pkg/tools/sqlcmd" +) + +type SqlScript struct { + Description string + Server string + Database string + Path string + Env map[string]string +} + +var ErrSqlScriptOperationDisabled = fmt.Errorf( + "%sYour project has sql server scripts.\n - %w\n%s\n", + output.WithWarningFormat("*Note: "), + ErrAzdOperationsNotEnabled, + output.WithWarningFormat("Ignoring scripts."), +) + +func SqlScripts(model AzdOperationsModel, infraPath string) ([]SqlScript, error) { + var sqlServerOperations []SqlScript + for _, operation := range model.Operations { + if operation.Type == sqlServerOperation { + var sqlServerScript SqlScript + bytes, err := json.Marshal(operation.Config) + if err != nil { + return nil, err + } + err = json.Unmarshal(bytes, &sqlServerScript) + if err != nil { + return nil, err + } + sqlServerScript.Description = operation.Description + if !filepath.IsAbs(sqlServerScript.Path) { + sqlServerScript.Path = filepath.Join(infraPath, sqlServerScript.Path) + } + sqlServerOperations = append(sqlServerOperations, sqlServerScript) + } + } + return sqlServerOperations, nil +} + +func DoSqlScript( + ctx context.Context, + SqlScriptsOperations []SqlScript, + console input.Console, + env environment.Environment, + sqlCmdCli *sqlcmd.SqlCmdCli, +) error { + if len(SqlScriptsOperations) > 0 { + console.ShowSpinner(ctx, "execute sql scripts", input.StepFailed) + } + for _, op := range SqlScriptsOperations { + filePath := op.Path + if op.Env != nil { + fileEnv := environment.NewWithValues("fileEnv", op.Env) + tmpDir, err := os.MkdirTemp("", "azd-sql-scripts") + if err != nil { + return err + } + defer os.RemoveAll(tmpDir) + data, err := os.ReadFile(filePath) + if err != nil { + return err + } + expString := osutil.NewExpandableString(string(data)) + evaluated, err := expString.Envsubst(fileEnv.Getenv) + if err != nil { + return err + } + filePath = filepath.Join(tmpDir, filepath.Base(filePath)) + err = os.WriteFile(filePath, []byte(evaluated), osutil.PermissionDirectory) + if err != nil { + return err + } + } + + if _, err := sqlCmdCli.ExecuteScript( + ctx, + op.Server, + op.Database, + filePath, + // sqlCmd cli uses DAC to connect to the server, but it doesn't know how to handle multi-tenant accounts. + // sqlCmd cli asks az or azd for a token w/o passing a tenant-id arg. + // sqlCmd cli runs from ~/.azd/bin: + // - azd doesn't know the tenant-id to use and defaults to get a token for home tenant. + // By setting the AZURE_SUBSCRIPTION_ID as env var to run sqlCmd cli, azd will use it to get tenant-id. + []string{ + fmt.Sprintf("%s=%s", environment.SubscriptionIdEnvVarName, env.GetSubscriptionId()), + }, + ); err != nil { + return fmt.Errorf("error run sqlcmd: %w", err) + } + console.MessageUxItem(ctx, &ux.DisplayedResource{ + Type: sqlServerOperation, + Name: op.Description, + State: ux.SucceededState, + }) + } + return nil +} diff --git a/cli/azd/pkg/project/dotnet_importer.go b/cli/azd/pkg/project/dotnet_importer.go index f33e1b62879..d962dca44ff 100644 --- a/cli/azd/pkg/project/dotnet_importer.go +++ b/cli/azd/pkg/project/dotnet_importer.go @@ -17,6 +17,7 @@ import ( "github.com/azure/azure-dev/cli/azd/pkg/environment" "github.com/azure/azure-dev/cli/azd/pkg/ext" "github.com/azure/azure-dev/cli/azd/pkg/infra/provisioning" + "github.com/azure/azure-dev/cli/azd/pkg/infra/provisioning/operations" "github.com/azure/azure-dev/cli/azd/pkg/input" "github.com/azure/azure-dev/cli/azd/pkg/lazy" "github.com/azure/azure-dev/cli/azd/pkg/osutil" @@ -107,12 +108,12 @@ func (ai *DotNetImporter) ProjectInfrastructure(ctx context.Context, svcConfig * return nil, fmt.Errorf("generating app host manifest: %w", err) } - azdOperationsEnabled := ai.alphaFeatureManager.IsEnabled(provisioning.AzdOperationsFeatureKey) + azdOperationsEnabled := ai.alphaFeatureManager.IsEnabled(operations.AzdOperationsFeatureKey) files, err := apphost.BicepTemplate("main", manifest, apphost.AppHostOptions{ AzdOperations: azdOperationsEnabled, }) if err != nil { - if errors.Is(err, provisioning.ErrAzdOperationsNotEnabled) { + if errors.Is(err, operations.ErrAzdOperationsNotEnabled) { // Use a warning for this error about azd operations is required for the current project to fully work ai.console.Message(ctx, err.Error()) } else { @@ -457,12 +458,12 @@ func (ai *DotNetImporter) SynthAllInfrastructure( rootModuleName = p.Infra.Module } - azdOperationsEnabled := ai.alphaFeatureManager.IsEnabled(provisioning.AzdOperationsFeatureKey) + azdOperationsEnabled := ai.alphaFeatureManager.IsEnabled(operations.AzdOperationsFeatureKey) infraFS, err := apphost.BicepTemplate(rootModuleName, manifest, apphost.AppHostOptions{ AzdOperations: azdOperationsEnabled, }) if err != nil { - if errors.Is(err, provisioning.ErrAzdOperationsNotEnabled) { + if errors.Is(err, operations.ErrAzdOperationsNotEnabled) { // Use a warning for this error about azd operations is required for the current project to fully work ai.console.Message(ctx, err.Error()) } else { From 29ac8203930c57e1f0cda3ee15b7d37fbb579963 Mon Sep 17 00:00:00 2001 From: Victor Vazquez Date: Wed, 24 Jul 2024 23:42:54 +0000 Subject: [PATCH 09/12] docs --- cli/azd/internal/cmd/provision.go | 5 +++- cli/azd/pkg/infra/provisioning/manager.go | 2 +- .../provisioning/operations/azd_operation.go | 5 ++++ .../operations/file_share_upload.go | 27 +++++++------------ .../provisioning/operations/sql_script.go | 4 +++ 5 files changed, 23 insertions(+), 20 deletions(-) diff --git a/cli/azd/internal/cmd/provision.go b/cli/azd/internal/cmd/provision.go index 1c186d68580..37c5c52a230 100644 --- a/cli/azd/internal/cmd/provision.go +++ b/cli/azd/internal/cmd/provision.go @@ -191,7 +191,10 @@ func (p *ProvisionAction) Run(ctx context.Context) (*actions.ActionResult, error return nil, fmt.Errorf("initializing provisioning manager: %w", err) } - // register operations + // ** Registering post-provisioning operations ** + // When azd.operations.yaml is found, the provisioning manager returns the list of operations to be executed + // as callbacks -> []func(ctx context.Context) error, error) + // See package `infra/provisioning/operations` for more details. operations, err := p.provisionManager.Operations(ctx) if err != nil { return nil, fmt.Errorf("registering operations: %w", err) diff --git a/cli/azd/pkg/infra/provisioning/manager.go b/cli/azd/pkg/infra/provisioning/manager.go index e128676a8b6..d01e55f1ee3 100644 --- a/cli/azd/pkg/infra/provisioning/manager.go +++ b/cli/azd/pkg/infra/provisioning/manager.go @@ -66,7 +66,7 @@ func (m *Manager) Operations(ctx context.Context) ([]func(ctx context.Context) e return result, fmt.Errorf("looking for azd fileShare upload operations: %w", err) } result = append(result, func(context context.Context) error { - return operations.DoBindMount( + return operations.DoFileShareUpload( context, bindMountOperations, m.env, m.console, m.fileShareService, m.cloud.StorageEndpointSuffix) }) } diff --git a/cli/azd/pkg/infra/provisioning/operations/azd_operation.go b/cli/azd/pkg/infra/provisioning/operations/azd_operation.go index 079d0f4be57..82dbb34f96b 100644 --- a/cli/azd/pkg/infra/provisioning/operations/azd_operation.go +++ b/cli/azd/pkg/infra/provisioning/operations/azd_operation.go @@ -16,12 +16,14 @@ import ( "gopkg.in/yaml.v3" ) +// azdOperation represents an operation that can be performed by the azd. type azdOperation struct { Type string Description string Config any } +// AzdOperationsModel is the abstraction of azd.operations.yaml file. It is used to unmarshal the yaml file into a struct. type AzdOperationsModel struct { Operations []azdOperation } @@ -32,12 +34,15 @@ const ( azdOperationsFileName string = "azd.operations.yaml" ) +// AzdOperationsFeatureKey is the alpha feature key for azd operations. var AzdOperationsFeatureKey = alpha.MustFeatureKey("azd.operations") +// ErrAzdOperationsNotEnabled is returned when azd operations are not enabled. var ErrAzdOperationsNotEnabled = fmt.Errorf(fmt.Sprintf( "azd operations (alpha feature) is required but disabled. You can enable azd operations by running: %s", output.WithGrayFormat(alpha.GetEnableCommand(AzdOperationsFeatureKey)))) +// AzdOperations returns the azd operations from the azd.operations.yaml file. func AzdOperations(infraPath string, env environment.Environment) (AzdOperationsModel, error) { path := filepath.Join(infraPath, azdOperationsFileName) data, err := os.ReadFile(path) diff --git a/cli/azd/pkg/infra/provisioning/operations/file_share_upload.go b/cli/azd/pkg/infra/provisioning/operations/file_share_upload.go index 82e24958405..9824174a2e2 100644 --- a/cli/azd/pkg/infra/provisioning/operations/file_share_upload.go +++ b/cli/azd/pkg/infra/provisioning/operations/file_share_upload.go @@ -15,6 +15,8 @@ import ( "github.com/azure/azure-dev/cli/azd/pkg/output/ux" ) +// FileShareUpload defines the configuration for a file share upload operation. +// When the operation is executed, the files in the specified path are uploaded to the specified file share. type FileShareUpload struct { Description string StorageAccount string @@ -22,6 +24,7 @@ type FileShareUpload struct { Path string } +// ErrBindMountOperationDisabled is returned when bind mount operations are disabled. var ErrBindMountOperationDisabled = fmt.Errorf( "%sYour project has bind mounts.\n - %w\n%s\n", output.WithWarningFormat("*Note: "), @@ -29,6 +32,7 @@ var ErrBindMountOperationDisabled = fmt.Errorf( output.WithWarningFormat("Ignoring bind mounts."), ) +// FileShareUploads returns the file share upload operations (if any) from the azd operations model. func FileShareUploads(model AzdOperationsModel) ([]FileShareUpload, error) { var fileShareUploadOperations []FileShareUpload for _, operation := range model.Operations { @@ -49,7 +53,9 @@ func FileShareUploads(model AzdOperationsModel) ([]FileShareUpload, error) { return fileShareUploadOperations, nil } -func DoBindMount( +// DoFileShareUpload performs the bind mount operations. +// It uploads the files in the specified path to the specified file share. +func DoFileShareUpload( ctx context.Context, fileShareUploadOperations []FileShareUpload, env *environment.Environment, @@ -61,14 +67,8 @@ func DoBindMount( console.ShowSpinner(ctx, "uploading files to fileShare", input.StepFailed) } for _, op := range fileShareUploadOperations { - if err := bindMountOperation( - ctx, - fileShareService, - cloudStorageEndpointSuffix, - env.GetSubscriptionId(), - op.StorageAccount, - op.FileShareName, - op.Path); err != nil { + shareUrl := fmt.Sprintf("https://%s.file.%s/%s", op.StorageAccount, cloudStorageEndpointSuffix, op.FileShareName) + if err := fileShareService.UploadPath(ctx, env.GetSubscriptionId(), shareUrl, op.Path); err != nil { return fmt.Errorf("error binding mount: %w", err) } console.MessageUxItem(ctx, &ux.DisplayedResource{ @@ -79,12 +79,3 @@ func DoBindMount( } return nil } - -func bindMountOperation( - ctx context.Context, - fileShareService storage.FileShareService, - cloud, subId, storageAccount, fileShareName, source string) error { - - shareUrl := fmt.Sprintf("https://%s.file.%s/%s", storageAccount, cloud, fileShareName) - return fileShareService.UploadPath(ctx, subId, shareUrl, source) -} diff --git a/cli/azd/pkg/infra/provisioning/operations/sql_script.go b/cli/azd/pkg/infra/provisioning/operations/sql_script.go index e053a1d5d1c..e13be998329 100644 --- a/cli/azd/pkg/infra/provisioning/operations/sql_script.go +++ b/cli/azd/pkg/infra/provisioning/operations/sql_script.go @@ -18,6 +18,7 @@ import ( "github.com/azure/azure-dev/cli/azd/pkg/tools/sqlcmd" ) +// SqlScript defines the configuration for a sql script operation. type SqlScript struct { Description string Server string @@ -26,6 +27,7 @@ type SqlScript struct { Env map[string]string } +// ErrSqlScriptOperationDisabled is returned when sql script operations are disabled. var ErrSqlScriptOperationDisabled = fmt.Errorf( "%sYour project has sql server scripts.\n - %w\n%s\n", output.WithWarningFormat("*Note: "), @@ -33,6 +35,7 @@ var ErrSqlScriptOperationDisabled = fmt.Errorf( output.WithWarningFormat("Ignoring scripts."), ) +// SqlScripts returns the sql script operations (if any) from the azd operations model. func SqlScripts(model AzdOperationsModel, infraPath string) ([]SqlScript, error) { var sqlServerOperations []SqlScript for _, operation := range model.Operations { @@ -56,6 +59,7 @@ func SqlScripts(model AzdOperationsModel, infraPath string) ([]SqlScript, error) return sqlServerOperations, nil } +// DoSqlScript performs the sql script operations. func DoSqlScript( ctx context.Context, SqlScriptsOperations []SqlScript, From 2e6d250fc72d41158e865bda5d8d0324b2cc85c0 Mon Sep 17 00:00:00 2001 From: Victor Vazquez Date: Wed, 24 Jul 2024 23:46:53 +0000 Subject: [PATCH 10/12] more docs --- cli/azd/pkg/infra/provisioning/manager.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cli/azd/pkg/infra/provisioning/manager.go b/cli/azd/pkg/infra/provisioning/manager.go index d01e55f1ee3..a6cc03067e6 100644 --- a/cli/azd/pkg/infra/provisioning/manager.go +++ b/cli/azd/pkg/infra/provisioning/manager.go @@ -45,6 +45,10 @@ const ( defaultPath = "infra" ) +// Operations checks if a file azd.operations.yaml exists in the infraPath and returns the list of operations +// defined in the file. Operations are grouped by type and wrapped in a function that can be executed. +// The result is a list of functions where each function represents an operations group type. +// The operations can be registered as project lifecycle events, for example, as post-provisioning operations. func (m *Manager) Operations(ctx context.Context) ([]func(ctx context.Context) error, error) { //Get a list of operations result := []func(ctx context.Context) error{} From 9b4efe15c9a1837d8e4074b7bfb59302149350ea Mon Sep 17 00:00:00 2001 From: Victor Vazquez Date: Thu, 25 Jul 2024 00:02:43 +0000 Subject: [PATCH 11/12] support ipAddress to --- cli/azd/pkg/azure/arm_template.go | 1 + cli/azd/pkg/httputil/util.go | 16 ++++++++++++++++ cli/azd/pkg/httputil/util_test.go | 10 ++++++++++ .../infra/provisioning/bicep/bicep_provider.go | 11 +++++++++++ 4 files changed, 38 insertions(+) diff --git a/cli/azd/pkg/azure/arm_template.go b/cli/azd/pkg/azure/arm_template.go index a8f4c4fc79d..a0f280dc678 100644 --- a/cli/azd/pkg/azure/arm_template.go +++ b/cli/azd/pkg/azure/arm_template.go @@ -107,6 +107,7 @@ const AzdMetadataTypeGenerate AzdMetadataType = "generate" const AzdMetadataTypePrincipalLogin AzdMetadataType = "principalLogin" const AzdMetadataTypePrincipalId AzdMetadataType = "principalId" const AzdMetadataTypePrincipalType AzdMetadataType = "principalType" +const AzdMetadataTypeIpAddress AzdMetadataType = "ipAddress" const AzdMetadataTypeGenerateOrManual AzdMetadataType = "generateOrManual" type AzdMetadata struct { diff --git a/cli/azd/pkg/httputil/util.go b/cli/azd/pkg/httputil/util.go index aff0f73010c..f412a482a19 100644 --- a/cli/azd/pkg/httputil/util.go +++ b/cli/azd/pkg/httputil/util.go @@ -119,3 +119,19 @@ func RetryAfter(resp *http.Response) time.Duration { return 0 } + +// GetIpAddress returns the public IP address of the caller. +func GetIpAddress() (string, error) { + resp, err := http.Get("https://api.ipify.org") + if err != nil { + return "", err + } + defer resp.Body.Close() + + data, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + return string(data), nil +} diff --git a/cli/azd/pkg/httputil/util_test.go b/cli/azd/pkg/httputil/util_test.go index c0b9aa13f90..fe3f1baf37a 100644 --- a/cli/azd/pkg/httputil/util_test.go +++ b/cli/azd/pkg/httputil/util_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "io" + "net" "net/http" "testing" @@ -38,3 +39,12 @@ func TestReadRawResponse(t *testing.T) { require.Equal(t, *expectedResponse, *actualResponse) }) } + +func TestGetIpAddress(t *testing.T) { + ip, err := GetIpAddress() + + require.NoError(t, err) + require.NotEmpty(t, ip) + validIp := net.ParseIP(ip) + require.NotNil(t, validIp) +} diff --git a/cli/azd/pkg/infra/provisioning/bicep/bicep_provider.go b/cli/azd/pkg/infra/provisioning/bicep/bicep_provider.go index a8a52a9c743..e62f1824899 100644 --- a/cli/azd/pkg/infra/provisioning/bicep/bicep_provider.go +++ b/cli/azd/pkg/infra/provisioning/bicep/bicep_provider.go @@ -32,6 +32,7 @@ import ( "github.com/azure/azure-dev/cli/azd/pkg/config" "github.com/azure/azure-dev/cli/azd/pkg/convert" "github.com/azure/azure-dev/cli/azd/pkg/environment" + "github.com/azure/azure-dev/cli/azd/pkg/httputil" "github.com/azure/azure-dev/cli/azd/pkg/infra" . "github.com/azure/azure-dev/cli/azd/pkg/infra/provisioning" "github.com/azure/azure-dev/cli/azd/pkg/input" @@ -2053,6 +2054,16 @@ func (p *BicepProvider) ensureParameters( mustSetParamAsConfig(key, pLogin, p.env.Config, param.Secure()) configModified = true continue + case azure.AzdMetadataTypeIpAddress: + ipAddress, err := httputil.GetIpAddress() + if err != nil { + return nil, fmt.Errorf("getting IP address for bicep parameter: %w", err) + } + configuredParameters[key] = azure.ArmParameterValue{ + Value: ipAddress, + } + // this metadata type is not saved to config as the IP can be dynamic. + continue default: // Do nothing log.Println("Skipping actions for azd unknown metadata bicep parameter with type: ", azdMetadataType) From 10f43da649ac115cbcac40bf40cc961e74761387 Mon Sep 17 00:00:00 2001 From: Victor Vazquez Date: Fri, 23 Aug 2024 05:33:20 +0000 Subject: [PATCH 12/12] remove event --- cli/azd/pkg/tools/sqlcmd/sqlcmd.go | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/cli/azd/pkg/tools/sqlcmd/sqlcmd.go b/cli/azd/pkg/tools/sqlcmd/sqlcmd.go index 6cf2a921f7e..bd0e5273171 100644 --- a/cli/azd/pkg/tools/sqlcmd/sqlcmd.go +++ b/cli/azd/pkg/tools/sqlcmd/sqlcmd.go @@ -19,8 +19,6 @@ import ( "strings" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" - "github.com/azure/azure-dev/cli/azd/internal/tracing" - "github.com/azure/azure-dev/cli/azd/internal/tracing/events" "github.com/azure/azure-dev/cli/azd/pkg/config" "github.com/azure/azure-dev/cli/azd/pkg/exec" "github.com/azure/azure-dev/cli/azd/pkg/input" @@ -312,10 +310,7 @@ func downloadSqlCmd( log.Printf("downloading sqlCmd cli release %s -> %s", sqlCmdReleaseUrl, releaseName) - spanCtx, span := tracing.Start(ctx, events.SqlCmdCliInstallEvent) - defer span.End() - - req, err := http.NewRequestWithContext(spanCtx, "GET", sqlCmdReleaseUrl, nil) + req, err := http.NewRequestWithContext(ctx, "GET", sqlCmdReleaseUrl, nil) if err != nil { return err }