From 081ecd1d44244d325643a98186e4aaad10f3c9d1 Mon Sep 17 00:00:00 2001 From: pauhull <22707808+pauhull@users.noreply.github.com> Date: Wed, 10 Jan 2024 13:24:23 +0100 Subject: [PATCH] feat: allow setting default SSH keys for contexts --- internal/cmd/context/context.go | 1 + internal/cmd/context/ssh_key.go | 211 ++++++++++++++++++++++ internal/cmd/context/ssh_key_test.go | 261 +++++++++++++++++++++++++++ internal/cmd/server/create.go | 7 + internal/cmd/server/create_test.go | 9 + internal/cmd/server/enable_rescue.go | 8 + internal/state/config/config.go | 15 +- 7 files changed, 506 insertions(+), 6 deletions(-) create mode 100644 internal/cmd/context/ssh_key.go create mode 100644 internal/cmd/context/ssh_key_test.go diff --git a/internal/cmd/context/context.go b/internal/cmd/context/context.go index b7138297..3b8c99f3 100644 --- a/internal/cmd/context/context.go +++ b/internal/cmd/context/context.go @@ -20,6 +20,7 @@ func NewCommand(s state.State) *cobra.Command { newUseCommand(s), newDeleteCommand(s), newListCommand(s), + newSSHKeyCommand(s), ) return cmd } diff --git a/internal/cmd/context/ssh_key.go b/internal/cmd/context/ssh_key.go new file mode 100644 index 00000000..4346040c --- /dev/null +++ b/internal/cmd/context/ssh_key.go @@ -0,0 +1,211 @@ +package context + +import ( + "fmt" + "os" + "slices" + "strings" + "sync" + + "github.com/spf13/cobra" + + "github.com/hetznercloud/cli/internal/cmd/cmpl" + "github.com/hetznercloud/cli/internal/state" + "github.com/hetznercloud/cli/internal/state/config" + "github.com/hetznercloud/hcloud-go/v2/hcloud" +) + +func newSSHKeyCommand(s state.State) *cobra.Command { + cmd := &cobra.Command{ + Use: "ssh-key", + Short: "Manage a context's default SSH key", + Args: cobra.NoArgs, + TraverseChildren: true, + DisableFlagsInUseLine: true, + } + cmd.AddCommand(newSSHKeyAddCommand(s)) + cmd.AddCommand(newSSHKeyRemoveCommand(s)) + cmd.AddCommand(newSSHKeyListCommand(s)) + return cmd +} + +func newSSHKeyAddCommand(s state.State) *cobra.Command { + cmd := &cobra.Command{ + Use: "add SSH-KEY...", + Short: "Add a default SSH key to the context", + TraverseChildren: true, + DisableFlagsInUseLine: true, + RunE: state.Wrap(s, runSSHKeyAdd), + } + cmd.Flags().String("context", "", "Name of the context to add the default SSH key(s) to") + _ = cmd.RegisterFlagCompletionFunc("context", cmpl.SuggestCandidates(config.ContextNames(s.Config())...)) + + cmd.Flags().Bool("all", false, "Add all available SSH keys to the context") + _ = cmd.RegisterFlagCompletionFunc("all", cmpl.SuggestCandidates("true", "false")) + return cmd +} + +func runSSHKeyAdd(s state.State, cmd *cobra.Command, args []string) error { + + ctx, err := getContext(s, cmd) + if err != nil { + return err + } + + s.Client().WithOpts(hcloud.WithToken(ctx.Token)) + keys := args + + all, _ := cmd.Flags().GetBool("all") + if all { + allKeys, err := s.Client().SSHKey().All(s) + if err != nil { + return err + } + if len(allKeys) == 0 { + return fmt.Errorf("no SSH keys available") + } + keys = make([]string, len(allKeys)) + for i, key := range allKeys { + keys[i] = fmt.Sprintf("%d", key.ID) + } + } else { + if len(keys) == 0 { + return fmt.Errorf("no SSH keys specified") + } + var ( + notExist []string + wg = &sync.WaitGroup{} + mu = &sync.Mutex{} + ) + wg.Add(len(keys)) + for _, key := range keys { + key := key + go func() { + k, _, _ := s.Client().SSHKey().Get(s, key) + if k == nil { + mu.Lock() + notExist = append(notExist, key) + mu.Unlock() + } + wg.Done() + }() + } + wg.Wait() + if len(notExist) > 0 { + _, _ = fmt.Fprintf(os.Stderr, "Warning: The given SSH key(s) %s do not exist in context \"%s\"\n", strings.Join(notExist, ", "), ctx.Name) + } + } + + ctx.SSHKeys = append(ctx.SSHKeys, keys...) + // remove duplicates + slices.Sort(ctx.SSHKeys) + ctx.SSHKeys = slices.Compact(ctx.SSHKeys) + + if err := s.Config().Write(); err != nil { + return err + } + + cmd.Printf("Added SSH key(s) %s to context \"%s\"\n", strings.Join(keys, ", "), ctx.Name) + return nil +} + +func newSSHKeyRemoveCommand(s state.State) *cobra.Command { + cmd := &cobra.Command{ + Use: "remove SSH-KEY...", + Short: "Remove a default SSH key from the context", + TraverseChildren: true, + DisableFlagsInUseLine: true, + RunE: state.Wrap(s, runSSHKeyRemove), + } + cmd.Flags().String("context", "", "Name of the context to remove the default SSH key(s) from") + _ = cmd.RegisterFlagCompletionFunc("context", cmpl.SuggestCandidates(config.ContextNames(s.Config())...)) + + cmd.Flags().Bool("all", false, "Remove all SSH keys from the context") + _ = cmd.RegisterFlagCompletionFunc("all", cmpl.SuggestCandidates("true", "false")) + return cmd +} + +func runSSHKeyRemove(s state.State, cmd *cobra.Command, args []string) error { + + ctx, err := getContext(s, cmd) + if err != nil { + return err + } + + keys := args + origLen := len(ctx.SSHKeys) + + all, _ := cmd.Flags().GetBool("all") + if all { + ctx.SSHKeys = nil + } else { + var newKeys []string + for _, key := range ctx.SSHKeys { + if slices.Contains(keys, key) { + continue + } + newKeys = append(newKeys, key) + } + ctx.SSHKeys = newKeys + } + + if err := s.Config().Write(); err != nil { + return err + } + + removed := origLen - len(ctx.SSHKeys) + cmd.Printf("Removed %d SSH key(s) from context \"%s\"\n", removed, ctx.Name) + return nil +} + +func newSSHKeyListCommand(s state.State) *cobra.Command { + cmd := &cobra.Command{ + Use: "list", + Short: "List all default SSH keys of the context", + TraverseChildren: true, + DisableFlagsInUseLine: true, + RunE: state.Wrap(s, runSSHKeyList), + } + cmd.Flags().String("context", "", "Name of the context to list the default SSH keys from") + _ = cmd.RegisterFlagCompletionFunc("context", cmpl.SuggestCandidates(config.ContextNames(s.Config())...)) + return cmd +} + +func runSSHKeyList(s state.State, cmd *cobra.Command, _ []string) error { + + ctx, err := getContext(s, cmd) + if err != nil { + return err + } + + if len(ctx.SSHKeys) == 0 { + cmd.Printf("No SSH keys in context \"%s\"\n", ctx.Name) + return nil + } + + cmd.Printf("SSH keys in context \"%s\":\n", ctx.Name) + for _, key := range ctx.SSHKeys { + cmd.Printf(" - %s\n", key) + } + return nil +} + +func getContext(s state.State, cmd *cobra.Command) (*config.Context, error) { + + var ctx *config.Context + + ctxName, _ := cmd.Flags().GetString("context") + if ctxName != "" { + ctx = config.ContextByName(s.Config(), ctxName) + if ctx == nil { + return nil, fmt.Errorf("context not found: %v", ctxName) + } + } else { + ctx = s.Config().ActiveContext() + if ctx == nil { + return nil, fmt.Errorf("no active context") + } + } + + return ctx, nil +} diff --git a/internal/cmd/context/ssh_key_test.go b/internal/cmd/context/ssh_key_test.go new file mode 100644 index 00000000..336d75ef --- /dev/null +++ b/internal/cmd/context/ssh_key_test.go @@ -0,0 +1,261 @@ +package context_test + +import ( + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + "github.com/hetznercloud/cli/internal/cli" + "github.com/hetznercloud/cli/internal/cmd/context" + "github.com/hetznercloud/cli/internal/state/config" + "github.com/hetznercloud/cli/internal/testutil" + "github.com/hetznercloud/hcloud-go/v2/hcloud" +) + +func TestSSHKeyAdd(t *testing.T) { + fx := testutil.NewFixture(t) + defer fx.Finish() + + rootCmd := cli.NewRootCommand(fx.State()) + + // needed because subcommands fetch a list of context names for completions + fx.Config.EXPECT(). + Contexts(). + AnyTimes(). + Return(nil) + + rootCmd.AddCommand(context.NewCommand(fx.State())) + + activeContext := &config.Context{ + Name: "test", + SSHKeys: []string{"existing_key"}, + } + + fx.Config.EXPECT(). + ActiveContext(). + Return(activeContext) + fx.Client.SSHKeyClient.EXPECT(). + Get(gomock.Any(), "new_key"). + Return(&hcloud.SSHKey{}, nil, nil) + fx.Config.EXPECT(). + Write() + + out, errOut, err := fx.Run(rootCmd, []string{"context", "ssh-key", "add", "new_key"}) + + assert.NoError(t, err) + assert.Equal(t, "Added SSH key(s) new_key to context \"test\"\n", out) + assert.Empty(t, errOut) + + assert.Equal(t, []string{"existing_key", "new_key"}, activeContext.SSHKeys) +} + +func TestSSHKeyAddAll(t *testing.T) { + fx := testutil.NewFixture(t) + defer fx.Finish() + + rootCmd := cli.NewRootCommand(fx.State()) + + // needed because subcommands fetch a list of context names for completions + fx.Config.EXPECT(). + Contexts(). + AnyTimes(). + Return(nil) + + rootCmd.AddCommand(context.NewCommand(fx.State())) + + activeContext := &config.Context{ + Name: "test", + SSHKeys: []string{"existing_key"}, + } + + fx.Config.EXPECT(). + ActiveContext(). + Return(activeContext) + fx.Client.SSHKeyClient.EXPECT(). + All(gomock.Any()). + Return([]*hcloud.SSHKey{{ID: 42, Name: "foo"}, {ID: 1337, Name: "bar"}}, nil) + fx.Config.EXPECT(). + Write() + + out, errOut, err := fx.Run(rootCmd, []string{"context", "ssh-key", "add", "--all"}) + + assert.NoError(t, err) + assert.Equal(t, "Added SSH key(s) 42, 1337 to context \"test\"\n", out) + assert.Empty(t, errOut) + + assert.Equal(t, []string{"1337", "42", "existing_key"}, activeContext.SSHKeys) +} + +func TestSSHKeyAddContext(t *testing.T) { + fx := testutil.NewFixture(t) + defer fx.Finish() + + rootCmd := cli.NewRootCommand(fx.State()) + + contexts := []*config.Context{ + { + Name: "test", + }, + { + Name: "test2", + SSHKeys: []string{"existing_key"}, + }, + } + + // needed because subcommands fetch a list of context names for completions + fx.Config.EXPECT(). + Contexts(). + AnyTimes(). + Return(contexts) + + rootCmd.AddCommand(context.NewCommand(fx.State())) + + fx.Client.SSHKeyClient.EXPECT(). + Get(gomock.Any(), "new_key"). + Return(&hcloud.SSHKey{}, nil, nil) + fx.Config.EXPECT(). + Write() + + out, errOut, err := fx.Run(rootCmd, []string{"context", "ssh-key", "add", "new_key", "--context", "test2"}) + + assert.NoError(t, err) + assert.Equal(t, "Added SSH key(s) new_key to context \"test2\"\n", out) + assert.Empty(t, errOut) + + assert.Equal(t, []string{"existing_key", "new_key"}, contexts[1].SSHKeys) +} + +func TestSSHKeyRemove(t *testing.T) { + fx := testutil.NewFixture(t) + defer fx.Finish() + + rootCmd := cli.NewRootCommand(fx.State()) + + // needed because subcommands fetch a list of context names for completions + fx.Config.EXPECT(). + Contexts(). + AnyTimes(). + Return(nil) + + rootCmd.AddCommand(context.NewCommand(fx.State())) + + activeContext := &config.Context{ + Name: "test", + SSHKeys: []string{"remove_me", "dont_remove_me"}, + } + + fx.Config.EXPECT(). + ActiveContext(). + Return(activeContext) + fx.Config.EXPECT(). + Write() + + out, errOut, err := fx.Run(rootCmd, []string{"context", "ssh-key", "remove", "remove_me"}) + + assert.NoError(t, err) + assert.Equal(t, "Removed 1 SSH key(s) from context \"test\"\n", out) + assert.Empty(t, errOut) + + assert.Equal(t, []string{"dont_remove_me"}, activeContext.SSHKeys) +} + +func TestSSHKeyRemoveContext(t *testing.T) { + fx := testutil.NewFixture(t) + defer fx.Finish() + + rootCmd := cli.NewRootCommand(fx.State()) + + contexts := []*config.Context{ + { + Name: "test", + }, + { + Name: "test2", + SSHKeys: []string{"remove_me", "dont_remove_me"}, + }, + } + + // needed because subcommands fetch a list of context names for completions + fx.Config.EXPECT(). + Contexts(). + AnyTimes(). + Return(contexts) + + rootCmd.AddCommand(context.NewCommand(fx.State())) + + fx.Config.EXPECT(). + Write() + + out, errOut, err := fx.Run(rootCmd, []string{"context", "ssh-key", "remove", "remove_me", "--context", "test2"}) + + assert.NoError(t, err) + assert.Equal(t, "Removed 1 SSH key(s) from context \"test2\"\n", out) + assert.Empty(t, errOut) + + assert.Equal(t, []string{"dont_remove_me"}, contexts[1].SSHKeys) +} + +func TestSSHKeyRemoveAll(t *testing.T) { + fx := testutil.NewFixture(t) + defer fx.Finish() + + rootCmd := cli.NewRootCommand(fx.State()) + + // needed because subcommands fetch a list of context names for completions + fx.Config.EXPECT(). + Contexts(). + AnyTimes(). + Return(nil) + + rootCmd.AddCommand(context.NewCommand(fx.State())) + + activeContext := &config.Context{ + Name: "test", + SSHKeys: []string{"remove_me", "remove_me_too"}, + } + + fx.Config.EXPECT(). + ActiveContext(). + Return(activeContext) + fx.Config.EXPECT(). + Write() + + out, errOut, err := fx.Run(rootCmd, []string{"context", "ssh-key", "remove", "--all"}) + + assert.NoError(t, err) + assert.Equal(t, "Removed 2 SSH key(s) from context \"test\"\n", out) + assert.Empty(t, errOut) + + assert.Empty(t, activeContext.SSHKeys) +} + +func TestSSHKeyList(t *testing.T) { + fx := testutil.NewFixture(t) + defer fx.Finish() + + rootCmd := cli.NewRootCommand(fx.State()) + + // needed because subcommands fetch a list of context names for completions + fx.Config.EXPECT(). + Contexts(). + AnyTimes(). + Return(nil) + + rootCmd.AddCommand(context.NewCommand(fx.State())) + + activeContext := &config.Context{ + Name: "test", + SSHKeys: []string{"foo", "bar", "baz"}, + } + + fx.Config.EXPECT(). + ActiveContext(). + Return(activeContext) + + out, errOut, err := fx.Run(rootCmd, []string{"context", "ssh-key", "list"}) + + assert.NoError(t, err) + assert.Equal(t, "SSH keys in context \"test\":\n - foo\n - bar\n - baz\n", out) + assert.Empty(t, errOut) +} diff --git a/internal/cmd/server/create.go b/internal/cmd/server/create.go index 3573314d..038b6268 100644 --- a/internal/cmd/server/create.go +++ b/internal/cmd/server/create.go @@ -261,6 +261,13 @@ func createOptsFromFlags( primaryIPv6IDorName, _ := flags.GetString("primary-ipv6") protection, _ := flags.GetStringSlice("enable-protection") + // check if ssh-key flag was set, otherwise set to defaults + if !flags.Changed("ssh-key") { + if ctx := s.Config().ActiveContext(); ctx != nil { + sshKeys = ctx.SSHKeys + } + } + serverType, _, err := s.Client().ServerType().Get(s, serverTypeName) if err != nil { return diff --git a/internal/cmd/server/create_test.go b/internal/cmd/server/create_test.go index 601aa33d..eeeea781 100644 --- a/internal/cmd/server/create_test.go +++ b/internal/cmd/server/create_test.go @@ -26,6 +26,9 @@ func TestCreate(t *testing.T) { fx.ExpectEnsureToken() + fx.Config.EXPECT(). + ActiveContext(). + Return(nil) fx.Client.ServerTypeClient.EXPECT(). Get(gomock.Any(), "cx11"). Return(&hcloud.ServerType{Architecture: hcloud.ArchitectureX86}, nil, nil) @@ -134,6 +137,9 @@ func TestCreateJSON(t *testing.T) { Status: hcloud.ServerStatusRunning, } + fx.Config.EXPECT(). + ActiveContext(). + Return(nil) fx.Client.ServerTypeClient.EXPECT(). Get(gomock.Any(), "cx11"). Return(&hcloud.ServerType{Architecture: hcloud.ArchitectureX86}, nil, nil) @@ -176,6 +182,9 @@ func TestCreateProtectionBackup(t *testing.T) { fx.ExpectEnsureToken() + fx.Config.EXPECT(). + ActiveContext(). + Return(nil) fx.Client.ServerTypeClient.EXPECT(). Get(gomock.Any(), "cx11"). Return(&hcloud.ServerType{Architecture: hcloud.ArchitectureX86}, nil, nil) diff --git a/internal/cmd/server/enable_rescue.go b/internal/cmd/server/enable_rescue.go index d75cfaf7..9ff91a02 100644 --- a/internal/cmd/server/enable_rescue.go +++ b/internal/cmd/server/enable_rescue.go @@ -46,6 +46,14 @@ var EnableRescueCmd = base.Cmd{ opts.Type = hcloud.ServerRescueType(rescueType) sshKeys, _ := cmd.Flags().GetStringSlice("ssh-key") + + // check if ssh-key flag was set, otherwise set to defaults + if !cmd.Flags().Changed("ssh-key") { + if ctx := s.Config().ActiveContext(); ctx != nil { + sshKeys = ctx.SSHKeys + } + } + for _, sshKeyIDOrName := range sshKeys { sshKey, _, err := s.Client().SSHKey().Get(s, sshKeyIDOrName) if err != nil { diff --git a/internal/state/config/config.go b/internal/state/config/config.go index 42594f73..4b6984db 100644 --- a/internal/state/config/config.go +++ b/internal/state/config/config.go @@ -22,8 +22,9 @@ type Config interface { } type Context struct { - Name string - Token string + Name string + Token string + SSHKeys []string } type config struct { @@ -128,8 +129,9 @@ type rawConfig struct { } type rawConfigContext struct { - Name string `toml:"name"` - Token string `toml:"token"` + Name string `toml:"name"` + Token string `toml:"token"` + SSHKeys []string `toml:"ssh_keys"` } func (cfg *config) marshal() ([]byte, error) { @@ -153,8 +155,9 @@ func (cfg *config) unmarshal(data []byte) error { } for _, rawContext := range raw.contexts { cfg.contexts = append(cfg.contexts, &Context{ - Name: rawContext.Name, - Token: rawContext.Token, + Name: rawContext.Name, + Token: rawContext.Token, + SSHKeys: rawContext.SSHKeys, }) } if raw.activeContext != "" {