diff --git a/cli/azd/cmd/auto_install.go b/cli/azd/cmd/auto_install.go new file mode 100644 index 00000000000..e39d08eab6b --- /dev/null +++ b/cli/azd/cmd/auto_install.go @@ -0,0 +1,406 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "context" + "fmt" + "log" + "os" + "strings" + + "github.com/azure/azure-dev/cli/azd/internal/tracing/resource" + "github.com/azure/azure-dev/cli/azd/pkg/alpha" + "github.com/azure/azure-dev/cli/azd/pkg/extensions" + "github.com/azure/azure-dev/cli/azd/pkg/input" + "github.com/azure/azure-dev/cli/azd/pkg/ioc" + "github.com/spf13/cobra" + "github.com/spf13/pflag" +) + +// extractFlagsWithValues extracts flags that take values from a cobra command. +// This ensures we have a single source of truth for flag definitions by +// dynamically inspecting the command's flag definitions instead of +// maintaining a separate hardcoded list. +// +// The function inspects both regular flags and persistent flags, checking +// the flag's value type to determine if it takes an argument: +// - Bool flags don't take values +// - String, Int, StringSlice, etc. flags do take values +func extractFlagsWithValues(cmd *cobra.Command) map[string]bool { + flagsWithValues := make(map[string]bool) + + // Extract flags that take values from the command + cmd.Flags().VisitAll(func(flag *pflag.Flag) { + // String, StringSlice, StringArray, Int, Int64, etc. all take values + // Bool flags don't take values + if flag.Value.Type() != "bool" { + flagsWithValues["--"+flag.Name] = true + if flag.Shorthand != "" { + flagsWithValues["-"+flag.Shorthand] = true + } + } + }) + + // Also check persistent flags (global flags) + // IMPORTANT: cmd.Flags().VisitAll() does NOT include persistent flags. + // In Cobra, cmd.Flags() only returns local flags specific to that command, + // while cmd.PersistentFlags() returns flags that are inherited by subcommands. + // These are separate flag sets, so we must call both VisitAll functions + // to capture all flags that can take values. + cmd.PersistentFlags().VisitAll(func(flag *pflag.Flag) { + if flag.Value.Type() != "bool" { + flagsWithValues["--"+flag.Name] = true + if flag.Shorthand != "" { + flagsWithValues["-"+flag.Shorthand] = true + } + } + }) + + return flagsWithValues +} + +// findFirstNonFlagArg finds the first argument that doesn't start with '-' and isn't a flag value. +// This function properly handles flags that take values (like --output json) to avoid +// incorrectly identifying flag values as commands. +// Returns the command and any unknown flags encountered before the command. +func findFirstNonFlagArg(args []string, flagsWithValues map[string]bool) (command string, unknownFlags []string) { + // Initialize as empty slice instead of nil for consistent behavior + unknownFlags = []string{} + + skipNext := false + for i, arg := range args { + // Skip this argument if it's marked as a flag value from previous iteration + if skipNext { + skipNext = false + continue + } + + // If it doesn't start with '-', it's a potential command + if !strings.HasPrefix(arg, "-") { + return arg, unknownFlags + } + + // Check if this is a known flag that takes a value + if flagsWithValues[arg] { + // This flag takes a value, so skip the next argument + skipNext = true + continue + } + + // Handle flags with '=' syntax like --output=json + if strings.Contains(arg, "=") { + parts := strings.SplitN(arg, "=", 2) + if flagsWithValues[parts[0]] { + // This is a known flag=value format, no need to skip next + continue + } + // Unknown flag with equals - record it + unknownFlags = append(unknownFlags, parts[0]) + continue + } + + // This is an unknown flag - record it + unknownFlags = append(unknownFlags, arg) + + // Conservative heuristic: if the next argument doesn't start with '-' + // and there are more args after it, assume the unknown flag takes a value + if i+1 < len(args) && i+2 < len(args) { + nextArg := args[i+1] + argAfterNext := args[i+2] + if !strings.HasPrefix(nextArg, "-") && !strings.HasPrefix(argAfterNext, "-") { + // Pattern: --unknown value command + // Skip the value, let command be found next + skipNext = true + } + } + } + + return "", unknownFlags +} + +// checkForMatchingExtensions checks for extensions that match any possible namespace +// from the command arguments. For example, "azd foo demo bar" will check for +// extensions with namespaces: "foo", "foo.demo", "foo.demo.bar" +func checkForMatchingExtensions( + ctx context.Context, extensionManager *extensions.Manager, args []string) ([]*extensions.ExtensionMetadata, error) { + if len(args) == 0 { + return nil, nil + } + + options := &extensions.ListOptions{} + registryExtensions, err := extensionManager.ListFromRegistry(ctx, options) + if err != nil { + return nil, err + } + + var matchingExtensions []*extensions.ExtensionMetadata + + // Generate all possible namespace combinations from the command arguments + // For "azd something demo foo" -> check "something", "something.demo", "something.demo.foo" + for i := 1; i <= len(args); i++ { + candidateNamespace := strings.Join(args[:i], ".") + + // Check if any extension has this exact namespace + for _, ext := range registryExtensions { + if ext.Namespace == candidateNamespace { + matchingExtensions = append(matchingExtensions, ext) + } + } + } + + return matchingExtensions, nil +} + +// promptForExtensionChoice prompts the user to choose from multiple matching extensions +func promptForExtensionChoice( + ctx context.Context, + console input.Console, + extensions []*extensions.ExtensionMetadata) (*extensions.ExtensionMetadata, error) { + + if len(extensions) == 0 { + return nil, nil + } + + if len(extensions) == 1 { + return extensions[0], nil + } + + console.Message(ctx, "Multiple extensions found that match your command:") + console.Message(ctx, "") + + options := make([]string, len(extensions)) + for i, ext := range extensions { + options[i] = fmt.Sprintf("%s (%s) - %s", ext.Namespace, ext.DisplayName, ext.Description) + console.Message(ctx, fmt.Sprintf(" %d. %s", i+1, options[i])) + } + console.Message(ctx, "") + + choice, err := console.Select(ctx, input.ConsoleOptions{ + Message: "Which extension would you like to install?", + Options: options, + }) + if err != nil { + return nil, err + } + + return extensions[choice], nil +} + +// isBuiltInCommand checks if the given command is a built-in command by examining +// the root command's command tree. This includes both core azd commands and any +// installed extensions, preventing auto-install from triggering for known commands. +func isBuiltInCommand(rootCmd *cobra.Command, commandName string) bool { + if commandName == "" { + return false + } + + // Check if the command exists in the root command's subcommands + for _, cmd := range rootCmd.Commands() { + if cmd.Name() == commandName { + return true + } + // Also check aliases + for _, alias := range cmd.Aliases { + if alias == commandName { + return true + } + } + } + + return false +} + +// tryAutoInstallExtension attempts to auto-install an extension if the unknown command matches an available +// extension namespace. Returns true if an extension was found and installed, false otherwise. +func tryAutoInstallExtension( + ctx context.Context, + console input.Console, + extensionManager *extensions.Manager, + extension extensions.ExtensionMetadata) (bool, error) { + + // Check if the extension is already installed + _, err := extensionManager.GetInstalled(extensions.LookupOptions{ + Id: extension.Id, + }) + if err == nil { + return false, nil + } + + // Return error if running in CI/CD environment + if resource.IsRunningOnCI() { + return false, + fmt.Errorf( + "Command '%s' not found, but there's an available extension that provides it.\n"+ + "However, auto-installation is not supported in CI/CD environments.\n"+ + "Run '%s' to install it manually.", + extension.Namespace, + fmt.Sprintf("azd extension install %s", extension.Id)) + } + + // Ask user for permission to auto-install the extension + console.Message(ctx, + fmt.Sprintf("Command '%s' not found, but there's an available extension that provides it.\n", extension.Namespace)) + console.Message(ctx, + fmt.Sprintf("Extension: %s (%s)\n", extension.DisplayName, extension.Description)) + shouldInstall, err := console.Confirm(ctx, input.ConsoleOptions{ + DefaultValue: true, + Message: "Would you like to install it?", + }) + if err != nil { + return false, nil + } + + if !shouldInstall { + return false, nil + } + + // Install the extension + console.Message(ctx, fmt.Sprintf("Installing extension '%s'...\n", extension.Id)) + filterOptions := &extensions.FilterOptions{} + _, err = extensionManager.Install(ctx, extension.Id, filterOptions) + if err != nil { + return false, fmt.Errorf("failed to install extension: %w", err) + } + + console.Message(ctx, fmt.Sprintf("Extension '%s' installed successfully!\n", extension.Id)) + return true, nil +} + +// ExecuteWithAutoInstall executes the command and handles auto-installation of extensions for unknown commands. +func ExecuteWithAutoInstall(ctx context.Context, rootContainer *ioc.NestedContainer) error { + // Creating the RootCmd takes care of registering common dependencies in rootContainer + rootCmd := NewRootCmd(false, nil, rootContainer) + + // Continue only if extensions feature is enabled + err := rootContainer.Invoke(func(alphaFeatureManager *alpha.FeatureManager) error { + if !alphaFeatureManager.IsEnabled(extensions.FeatureExtensions) { + return fmt.Errorf("extensions feature is not enabled") + } + return nil + }) + if err != nil { + // Error here means extensions are not enabled or failed to resolve the feature manager + // In either case, we just proceed to normal execution + log.Println("auto-install extensions: ", err) + return rootCmd.ExecuteContext(ctx) + } + + // rootCmd.Find() returns the root command if no subcommand is identified. Cobra checks all the registered commands + // and returns the longest matching command. If no subcommand is found, it returns the root command itself. + // This allows us to determine if a subcommand was provided or not or if the command is unknown. + topCommand, originalArgs, err := rootCmd.Find(os.Args[1:]) + if err != nil { + // If we can't parse the command, just proceed to normal execution + log.Println("Error: parse command. Skipping auto-install:", err) + return rootCmd.ExecuteContext(ctx) + } + if topCommand != rootCmd || len(originalArgs) == 0 { + // known command to be run OR no subcommand provided - skip auto-install + return rootCmd.ExecuteContext(ctx) + } + + // Extract flags that take values from the root command + flagsWithValues := extractFlagsWithValues(rootCmd) + + // Find the first non-flag argument (the actual command) and check for unknown flags + unknownCommand, unknownFlags := findFirstNonFlagArg(originalArgs, flagsWithValues) + + // If we have a command, check if it's a built-in command first + if unknownCommand != "" { + // Check if this is a built-in command first (includes core commands and installed extensions) + if isBuiltInCommand(rootCmd, unknownCommand) { + // This is a built-in command, proceed with normal execution without checking for extensions + return rootCmd.ExecuteContext(ctx) + } + + // If unknown flags were found before a non-built-in command, return an error with helpful guidance + if len(unknownFlags) > 0 { + var console input.Console + if err := rootContainer.Resolve(&console); err != nil { + log.Panic("failed to resolve console for unknown flags error:", err) + } + + flagsList := strings.Join(unknownFlags, ", ") + errorMsg := fmt.Sprintf( + "Unknown flags detected before command '%s': %s\n\n"+ + "If you're trying to run an extension command, the extension name must come BEFORE any flags.\n"+ + "This is because extension-specific flags are not known until the extension is installed.\n\n"+ + "Correct usage:\n"+ + " azd %s %s # Extension name first, then flags\n"+ + " azd %s --help # Get help for the extension\n\n"+ + "If this is not an extension command, please check the flag names for typos.", + unknownCommand, flagsList, + unknownCommand, strings.Join(unknownFlags, " "), + unknownCommand) + + console.Message(ctx, errorMsg) + return fmt.Errorf("unknown flags before command: %s", flagsList) + } + + var extensionManager *extensions.Manager + if err := rootContainer.Resolve(&extensionManager); err != nil { + log.Panic("failed to resolve extension manager for auto-install:", err) + } + + // Get all remaining arguments starting from the command for namespace matching + // This allows checking longer namespaces like "something.demo.foo" from "azd something demo foo" + var argsForMatching []string + for i, arg := range originalArgs { + if !strings.HasPrefix(arg, "-") && arg == unknownCommand { + // Found the command, collect all non-flag arguments from here + for j := i; j < len(originalArgs); j++ { + if !strings.HasPrefix(originalArgs[j], "-") { + argsForMatching = append(argsForMatching, originalArgs[j]) + } + } + break + } + } + + // Check if any commands might match extensions with various namespace lengths + extensionMatches, err := checkForMatchingExtensions(ctx, extensionManager, argsForMatching) + if err != nil { + // Do not fail if we couldn't check for extensions - just proceed to normal execution + log.Println("Error: check for extensions. Skipping auto-install:", err) + return rootCmd.ExecuteContext(ctx) + } + + if len(extensionMatches) > 0 { + var console input.Console + if err := rootContainer.Resolve(&console); err != nil { + log.Panic("failed to resolve console for auto-install:", err) + } + + // Prompt user to choose if multiple extensions match + chosenExtension, err := promptForExtensionChoice(ctx, console, extensionMatches) + if err != nil { + console.Message(ctx, fmt.Sprintf("Error selecting extension: %v", err)) + return rootCmd.ExecuteContext(ctx) + } + + if chosenExtension == nil { + // User cancelled selection, proceed to normal execution + return rootCmd.ExecuteContext(ctx) + } + + // Try to auto-install the chosen extension + installed, installErr := tryAutoInstallExtension(ctx, console, extensionManager, *chosenExtension) + if installErr != nil { + // Error needs to be printed here or else it will be hidden b/c the error printing is handled inside runtime + console.Message(ctx, installErr.Error()) + return installErr + } + + if installed { + // Extension was installed, build command tree and execute + rootCmd := NewRootCmd(false, nil, rootContainer) + return rootCmd.ExecuteContext(ctx) + } + } + } + + // Normal execution path - either no args, no matching extension, or user declined install + return rootCmd.ExecuteContext(ctx) +} diff --git a/cli/azd/cmd/auto_install_builtin_test.go b/cli/azd/cmd/auto_install_builtin_test.go new file mode 100644 index 00000000000..c01c38542a5 --- /dev/null +++ b/cli/azd/cmd/auto_install_builtin_test.go @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "testing" + + "github.com/spf13/cobra" +) + +func TestIsBuiltInCommand(t *testing.T) { + // Create a mock root command with some subcommands + rootCmd := &cobra.Command{ + Use: "azd", + } + + // Add some built-in commands + upCmd := &cobra.Command{ + Use: "up", + } + rootCmd.AddCommand(upCmd) + + initCmd := &cobra.Command{ + Use: "init", + Aliases: []string{"initialize"}, + } + rootCmd.AddCommand(initCmd) + + downCmd := &cobra.Command{ + Use: "down", + } + rootCmd.AddCommand(downCmd) + + tests := []struct { + name string + commandName string + expected bool + }{ + { + name: "built-in command up returns true", + commandName: "up", + expected: true, + }, + { + name: "built-in command init returns true", + commandName: "init", + expected: true, + }, + { + name: "built-in command down returns true", + commandName: "down", + expected: true, + }, + { + name: "command alias initialize returns true", + commandName: "initialize", + expected: true, + }, + { + name: "non-existent command returns false", + commandName: "demo", + expected: false, + }, + { + name: "empty command name returns false", + commandName: "", + expected: false, + }, + { + name: "unknown command returns false", + commandName: "nonexistent", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isBuiltInCommand(rootCmd, tt.commandName) + if result != tt.expected { + t.Errorf("isBuiltInCommand(%q) = %v, expected %v", tt.commandName, result, tt.expected) + } + }) + } +} diff --git a/cli/azd/cmd/auto_install_integration_test.go b/cli/azd/cmd/auto_install_integration_test.go new file mode 100644 index 00000000000..45719e41d8e --- /dev/null +++ b/cli/azd/cmd/auto_install_integration_test.go @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "os" + "testing" + + "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" +) + +// TestExecuteWithAutoInstallIntegration tests the integration between +// extractFlagsWithValues and findFirstNonFlagArg in the context of +// the auto-install feature. +func TestExecuteWithAutoInstallIntegration(t *testing.T) { + // Save original args + originalArgs := os.Args + + // Test cases that would have failed before the fix + testCases := []struct { + name string + args []string + expected string + }{ + { + name: "output flag with demo command", + args: []string{"azd", "--output", "json", "demo"}, + expected: "demo", + }, + { + name: "cwd flag with init command", + args: []string{"azd", "--cwd", "/project", "init"}, + expected: "init", + }, + { + name: "mixed flags", + args: []string{"azd", "--debug", "--output", "table", "--no-prompt", "deploy"}, + expected: "deploy", + }, + { + name: "short form flags", + args: []string{"azd", "-o", "json", "-C", "/path", "up"}, + expected: "up", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Set test args + os.Args = tc.args + + // Create a test root command to extract flags from + rootCmd := &cobra.Command{Use: "azd"} + + // Add the flags that azd actually uses + rootCmd.PersistentFlags().StringP("output", "o", "", "Output format") + rootCmd.PersistentFlags().StringP("cwd", "C", "", "Working directory") + rootCmd.PersistentFlags().Bool("debug", false, "Debug mode") + rootCmd.PersistentFlags().Bool("no-prompt", false, "No prompting") + + // Extract flags and test our parsing + flagsWithValues := extractFlagsWithValues(rootCmd) + result, _ := findFirstNonFlagArg(os.Args[1:], flagsWithValues) + + assert.Equal(t, tc.expected, result, + "Failed to correctly identify command in args: %v", tc.args) + }) + } + + // Restore original args + os.Args = originalArgs +} diff --git a/cli/azd/cmd/auto_install_multi_namespace_test.go b/cli/azd/cmd/auto_install_multi_namespace_test.go new file mode 100644 index 00000000000..f53b34e4ca6 --- /dev/null +++ b/cli/azd/cmd/auto_install_multi_namespace_test.go @@ -0,0 +1,130 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "strings" + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/extensions" + "github.com/stretchr/testify/assert" +) + +func TestCheckForMatchingExtensionsLogic(t *testing.T) { + // Test the core logic without needing to mock the extension manager + // We'll create a simple function that mimics the matching logic + + testExtensions := []*extensions.ExtensionMetadata{ + { + Id: "extension1", + Namespace: "demo", + DisplayName: "Demo Extension", + Description: "Simple demo extension", + }, + { + Id: "extension2", + Namespace: "vhvb.demo", + DisplayName: "VHVB Demo Extension", + Description: "VHVB namespace demo extension", + }, + { + Id: "extension3", + Namespace: "vhvb.demo.advanced", + DisplayName: "Advanced VHVB Demo", + Description: "Advanced demo with longer namespace", + }, + { + Id: "extension4", + Namespace: "other.namespace", + DisplayName: "Other Extension", + Description: "Different namespace pattern", + }, + } + + // Helper function that mimics checkForMatchingExtensions logic + checkMatches := func( + args []string, availableExtensions []*extensions.ExtensionMetadata) []*extensions.ExtensionMetadata { + if len(args) == 0 { + return nil + } + + var matchingExtensions []*extensions.ExtensionMetadata + + // Generate all possible namespace combinations from the command arguments + for i := 1; i <= len(args); i++ { + candidateNamespace := strings.Join(args[:i], ".") + + // Check if any extension has this exact namespace + for _, ext := range availableExtensions { + if ext.Namespace == candidateNamespace { + matchingExtensions = append(matchingExtensions, ext) + } + } + } + + return matchingExtensions + } + + tests := []struct { + name string + args []string + expectedMatches []string // Extension IDs that should match + }{ + { + name: "single word matches single extension", + args: []string{"demo"}, + expectedMatches: []string{"extension1"}, + }, + { + name: "two words matches nested namespace", + args: []string{"vhvb", "demo"}, + expectedMatches: []string{"extension2"}, + }, + { + name: "three words matches deep namespace", + args: []string{"vhvb", "demo", "advanced"}, + expectedMatches: []string{"extension2", "extension3"}, // Both vhvb.demo and vhvb.demo.advanced should match + }, + { + name: "multiple matches for progressive namespaces", + args: []string{"vhvb", "demo", "advanced", "extra"}, + expectedMatches: []string{"extension2", "extension3"}, // Both vhvb.demo and vhvb.demo.advanced should match + }, + { + name: "no matches for unknown namespace", + args: []string{"unknown", "command"}, + expectedMatches: []string{}, + }, + { + name: "empty args returns no matches", + args: []string{}, + expectedMatches: []string{}, + }, + { + name: "partial namespace without full match", + args: []string{"vhvb"}, + expectedMatches: []string{}, // No extension with namespace "vhvb" exists + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Execute function + matches := checkMatches(tt.args, testExtensions) + + // Verify results + assert.Equal(t, len(tt.expectedMatches), len(matches), "Number of matches should be correct") + + // Check that the right extensions were matched + matchedIds := make([]string, len(matches)) + for i, match := range matches { + matchedIds[i] = match.Id + } + + for _, expectedId := range tt.expectedMatches { + assert.Contains(t, matchedIds, expectedId, "Expected extension %s to be in matches", expectedId) + } + }) + } +} diff --git a/cli/azd/cmd/auto_install_test.go b/cli/azd/cmd/auto_install_test.go new file mode 100644 index 00000000000..2ffc0428b7a --- /dev/null +++ b/cli/azd/cmd/auto_install_test.go @@ -0,0 +1,329 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "strings" + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/extensions" + "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" +) + +func TestFindFirstNonFlagArg(t *testing.T) { + // Mock flags that take values for testing + flagsWithValues := map[string]bool{ + "--output": true, + "-o": true, + "--cwd": true, + "-C": true, + "--trace-log-file": true, + "--trace-log-url": true, + "--config": true, // Additional test flag + } + + tests := []struct { + name string + args []string + expected string + }{ + { + name: "first arg is command", + args: []string{"demo", "--flag", "value"}, + expected: "demo", + }, + { + name: "command after boolean flags", + args: []string{"--debug", "--no-prompt", "demo"}, + expected: "demo", + }, + { + name: "only flags", + args: []string{"--help", "--version"}, + expected: "", + }, + { + name: "empty args", + args: []string{}, + expected: "", + }, + { + name: "flags with equals", + args: []string{"--output=json", "demo", "--template=web"}, + expected: "demo", + }, + { + name: "single character boolean flags", + args: []string{"-v", "-h", "up", "--debug"}, + expected: "up", + }, + { + name: "command with output flag value (the original problem)", + args: []string{"--output", "json", "demo", "subcommand"}, + expected: "demo", // Fixed: should be "demo", not "json" + }, + { + name: "command with cwd flag value", + args: []string{"--cwd", "/some/path", "demo"}, + expected: "demo", + }, + { + name: "command with short output flag", + args: []string{"-o", "table", "init"}, + expected: "init", + }, + { + name: "command with short cwd flag", + args: []string{"-C", "/path", "up"}, + expected: "up", + }, + { + name: "mixed flags with values and boolean", + args: []string{"--debug", "--output", "json", "--no-prompt", "deploy"}, + expected: "deploy", + }, + { + name: "no arguments", + args: nil, + expected: "", + }, + { + name: "trace log flags", + args: []string{"--trace-log-file", "debug.log", "monitor"}, + expected: "monitor", + }, + { + name: "complex real world example", + args: []string{"--debug", "--cwd", "/project", "--output", "json", "demo", "--template", "minimal"}, + expected: "demo", + }, + { + name: "test with custom config flag", + args: []string{"--config", "myconfig.yaml", "deploy"}, + expected: "deploy", + }, + { + name: "unknown flag that appears boolean", + args: []string{"--unknown", "command"}, + expected: "command", + }, + { + name: "unknown flag that takes value - PROBLEMATIC CASE", + args: []string{"--unknown-flag", "some-value", "command"}, + expected: "command", // Currently returns "some-value" - this is the problem! + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, _ := findFirstNonFlagArg(tt.args, flagsWithValues) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestFindFirstNonFlagArgWithUnknownFlags(t *testing.T) { + flagsWithValues := map[string]bool{ + "--output": true, + "-o": true, + "--cwd": true, + "-C": true, + } + + tests := []struct { + name string + args []string + expectedCommand string + expectedUnknownFlags []string + }{ + { + name: "no unknown flags", + args: []string{"--output", "json", "deploy"}, + expectedCommand: "deploy", + expectedUnknownFlags: []string{}, + }, + { + name: "single unknown flag before command", + args: []string{"--unknown", "command"}, + expectedCommand: "command", + expectedUnknownFlags: []string{"--unknown"}, + }, + { + name: "unknown flag that takes value", + args: []string{"--unknown-flag", "some-value", "command"}, + expectedCommand: "command", + expectedUnknownFlags: []string{"--unknown-flag"}, + }, + { + name: "multiple unknown flags", + args: []string{"--flag1", "--flag2", "value", "command"}, + expectedCommand: "command", + expectedUnknownFlags: []string{"--flag1", "--flag2"}, + }, + { + name: "mixed known and unknown flags", + args: []string{"--output", "json", "--unknown", "deploy"}, + expectedCommand: "deploy", + expectedUnknownFlags: []string{"--unknown"}, + }, + { + name: "unknown flag with equals", + args: []string{"--unknown=value", "command"}, + expectedCommand: "command", + expectedUnknownFlags: []string{"--unknown"}, + }, + { + name: "only unknown flags, no command", + args: []string{"--unknown1", "--unknown2"}, + expectedCommand: "", + expectedUnknownFlags: []string{"--unknown1", "--unknown2"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + command, unknownFlags := findFirstNonFlagArg(tt.args, flagsWithValues) + assert.Equal(t, tt.expectedCommand, command) + assert.Equal(t, tt.expectedUnknownFlags, unknownFlags) + }) + } +} + +func TestExtractFlagsWithValues(t *testing.T) { + // Create a test command with various flag types + cmd := &cobra.Command{ + Use: "test", + } + + // Add flags that take values + cmd.Flags().StringP("output", "o", "", "Output format") + cmd.PersistentFlags().StringP("cwd", "C", "", "Working directory") + cmd.Flags().String("config", "", "Config file") + + // Add boolean flags (should not be included) + cmd.Flags().Bool("debug", false, "Debug mode") + cmd.PersistentFlags().Bool("no-prompt", false, "No prompting") + + // Add flags with other value types + cmd.Flags().Int("port", 8080, "Port number") + cmd.Flags().StringSlice("tags", []string{}, "Tags") + + // Extract flags + flagsWithValues := extractFlagsWithValues(cmd) + + // Test that flags with values are included + assert.True(t, flagsWithValues["--output"], "Should include --output flag") + assert.True(t, flagsWithValues["-o"], "Should include -o shorthand") + assert.True(t, flagsWithValues["--cwd"], "Should include --cwd persistent flag") + assert.True(t, flagsWithValues["-C"], "Should include -C shorthand") + assert.True(t, flagsWithValues["--config"], "Should include --config flag") + assert.True(t, flagsWithValues["--port"], "Should include --port flag (int type)") + assert.True(t, flagsWithValues["--tags"], "Should include --tags flag (slice type)") + + // Test that boolean flags are NOT included + assert.False(t, flagsWithValues["--debug"], "Should not include boolean --debug flag") + assert.False(t, flagsWithValues["--no-prompt"], "Should not include boolean --no-prompt flag") + + // Test non-existent flags + assert.False(t, flagsWithValues["--nonexistent"], "Should not include non-existent flags") +} + +func TestCheckForMatchingExtension_Unit(t *testing.T) { + // This is a unit test that tests the logic without external dependencies + // We'll create a mock-like test by testing the namespace matching logic directly + + testCases := []struct { + name string + command string + extensions []*extensions.ExtensionMetadata + expectedMatch bool + expectedExtId string + }{ + { + name: "matches extension by first namespace part", + command: "demo", + extensions: []*extensions.ExtensionMetadata{ + { + Id: "microsoft.azd.demo", + Namespace: "demo.commands", + }, + }, + expectedMatch: true, + expectedExtId: "microsoft.azd.demo", + }, + { + name: "no match for command", + command: "nonexistent", + extensions: []*extensions.ExtensionMetadata{ + { + Id: "microsoft.azd.demo", + Namespace: "demo.commands", + }, + }, + expectedMatch: false, + }, + { + name: "matches complex namespace", + command: "complex", + extensions: []*extensions.ExtensionMetadata{ + { + Id: "microsoft.azd.complex", + Namespace: "complex.deep.namespace.structure", + }, + }, + expectedMatch: true, + expectedExtId: "microsoft.azd.complex", + }, + { + name: "multiple extensions, finds correct match", + command: "x", + extensions: []*extensions.ExtensionMetadata{ + { + Id: "microsoft.azd.demo", + Namespace: "demo.commands", + }, + { + Id: "microsoft.azd.x", + Namespace: "x.tools", + }, + { + Id: "microsoft.azd.other", + Namespace: "other.namespace", + }, + }, + expectedMatch: true, + expectedExtId: "microsoft.azd.x", + }, + { + name: "empty extensions list", + command: "demo", + extensions: []*extensions.ExtensionMetadata{}, + expectedMatch: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Test the namespace matching logic directly + var foundExtension *extensions.ExtensionMetadata + for _, ext := range tc.extensions { + namespaceParts := strings.Split(ext.Namespace, ".") + if len(namespaceParts) > 0 && namespaceParts[0] == tc.command { + foundExtension = ext + break + } + } + + if tc.expectedMatch { + assert.NotNil(t, foundExtension, "Expected to find matching extension") + if foundExtension != nil { + assert.Equal(t, tc.expectedExtId, foundExtension.Id) + } + } else { + assert.Nil(t, foundExtension, "Expected no matching extension") + } + }) + } +} diff --git a/cli/azd/main.go b/cli/azd/main.go index 0be81691ef6..8114c7c4f40 100644 --- a/cli/azd/main.go +++ b/cli/azd/main.go @@ -62,7 +62,9 @@ func main() { rootContainer := ioc.NewNestedContainer(nil) ioc.RegisterInstance(rootContainer, ctx) - cmdErr := cmd.NewRootCmd(false, nil, rootContainer).ExecuteContext(ctx) + + // Execute command with auto-installation support for extensions + cmdErr := cmd.ExecuteWithAutoInstall(ctx, rootContainer) oneauth.Shutdown()