Skip to content

Commit

Permalink
feat: allow setting default SSH keys for contexts
Browse files Browse the repository at this point in the history
  • Loading branch information
phm07 committed Jan 10, 2024
1 parent d89a322 commit 79ee38e
Show file tree
Hide file tree
Showing 6 changed files with 249 additions and 8 deletions.
1 change: 1 addition & 0 deletions internal/cmd/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ func NewCommand(s state.State) *cobra.Command {
newUseCommand(s),
newDeleteCommand(s),
newListCommand(s),
newSSHKeyCommand(s),
)
return cmd
}
212 changes: 212 additions & 0 deletions internal/cmd/context/ssh_key.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
package context

import (
"fmt"
"os"
"slices"
"strings"
"sync"

"github.com/spf13/cobra"

"github.com/hetznercloud/cli/internal/cmd/cmpl"
"github.com/hetznercloud/cli/internal/hcapi2"
"github.com/hetznercloud/cli/internal/state"
"github.com/hetznercloud/cli/internal/state/config"
"github.com/hetznercloud/hcloud-go/v2/hcloud"
)

// TODO: make *state.Config mockable and write tests

func newSSHKeyCommand(s state.State) *cobra.Command {
cmd := &cobra.Command{
Use: "ssh-key",
Short: "Manage a context's default SSH key",
Args: cobra.NoArgs,
TraverseChildren: true,
DisableFlagsInUseLine: true,
}
cmd.AddCommand(newSSHKeyAddCommand(s))
cmd.AddCommand(newSSHKeyRemoveCommand(s))
cmd.AddCommand(newSSHKeyListCommand(s))
return cmd
}

func newSSHKeyAddCommand(s state.State) *cobra.Command {
cmd := &cobra.Command{
Use: "add SSH-KEY...",
Short: "Add a default SSH key to the context",
TraverseChildren: true,
DisableFlagsInUseLine: true,
RunE: state.Wrap(s, runSSHKeyAdd),
}
cmd.Flags().String("context", "", "Name of the context to add the default SSH key(s) to")
_ = cmd.RegisterFlagCompletionFunc("context", cmpl.SuggestCandidates(s.Config().ContextNames()...))

cmd.Flags().Bool("all", false, "Add all available SSH keys to the context")
_ = cmd.RegisterFlagCompletionFunc("all", cmpl.SuggestCandidates("true", "false"))
return cmd
}

func runSSHKeyAdd(s state.State, cmd *cobra.Command, args []string) error {

ctx, err := getContext(s, cmd)
if err != nil {
return err
}

client := hcapi2.NewClient(hcloud.WithToken(ctx.Token))
keys := args

all, _ := cmd.Flags().GetBool("all")
if all {
allKeys, err := client.SSHKey().All(s)
if err != nil {
return err
}
if len(allKeys) == 0 {
return fmt.Errorf("no SSH keys available")
}
keys = make([]string, len(allKeys))
for i, key := range allKeys {
keys[i] = fmt.Sprintf("%d", key.ID)
}
} else {
if len(keys) == 0 {
return fmt.Errorf("no SSH keys specified")
}
var (
notExist []string
wg = &sync.WaitGroup{}
mu = &sync.Mutex{}
)
wg.Add(len(keys))
for _, key := range keys {
key := key
go func() {
k, _, _ := client.SSHKey().Get(s, key)
if k == nil {
mu.Lock()
notExist = append(notExist, key)
mu.Unlock()
}
wg.Done()
}()
}
wg.Wait()
_, _ = fmt.Fprintf(os.Stderr, "Warning: The given SSH key(s) %s do not exist in context \"%s\"\n", strings.Join(notExist, ", "), ctx.Name)
}

ctx.SSHKeys = append(ctx.SSHKeys, keys...)
// remove duplicates
slices.Sort(ctx.SSHKeys)
ctx.SSHKeys = slices.Compact(ctx.SSHKeys)

if err := s.Config().Write(); err != nil {
return err
}

cmd.Printf("Added SSH key(s) %s to context \"%s\"\n", strings.Join(keys, ", "), ctx.Name)
return nil
}

func newSSHKeyRemoveCommand(s state.State) *cobra.Command {
cmd := &cobra.Command{
Use: "remove SSH-KEY...",
Short: "Remove a default SSH key from the context",
TraverseChildren: true,
DisableFlagsInUseLine: true,
RunE: state.Wrap(s, runSSHKeyRemove),
}
cmd.Flags().String("context", "", "Name of the context to remove the default SSH key(s) from")
_ = cmd.RegisterFlagCompletionFunc("context", cmpl.SuggestCandidates(s.Config().ContextNames()...))

cmd.Flags().Bool("all", false, "Remove all SSH keys from the context")
_ = cmd.RegisterFlagCompletionFunc("all", cmpl.SuggestCandidates("true", "false"))
return cmd
}

func runSSHKeyRemove(s state.State, cmd *cobra.Command, args []string) error {

ctx, err := getContext(s, cmd)
if err != nil {
return err
}

keys := args
origLen := len(ctx.SSHKeys)

all, _ := cmd.Flags().GetBool("all")
if all {
ctx.SSHKeys = nil
} else {
var newKeys []string
for _, key := range ctx.SSHKeys {
if slices.Contains(keys, key) {
continue
}
newKeys = append(newKeys, key)
}
ctx.SSHKeys = newKeys
}

if err := s.Config().Write(); err != nil {
return err
}

removed := origLen - len(ctx.SSHKeys)
cmd.Printf("Removed %d SSH key(s) from context \"%s\"\n", removed, ctx.Name)
return nil
}

func newSSHKeyListCommand(s state.State) *cobra.Command {
cmd := &cobra.Command{
Use: "list",
Short: "List all default SSH keys of the context",
TraverseChildren: true,
DisableFlagsInUseLine: true,
RunE: state.Wrap(s, runSSHKeyList),
}
cmd.Flags().String("context", "", "Name of the context to list the default SSH keys from")
_ = cmd.RegisterFlagCompletionFunc("context", cmpl.SuggestCandidates(s.Config().ContextNames()...))
return cmd
}

func runSSHKeyList(s state.State, cmd *cobra.Command, _ []string) error {

ctx, err := getContext(s, cmd)
if err != nil {
return err
}

if len(ctx.SSHKeys) == 0 {
cmd.Printf("No SSH keys in context \"%s\"\n", ctx.Name)
return nil
}

cmd.Printf("SSH keys in context \"%s\":\n", ctx.Name)
for _, key := range ctx.SSHKeys {
cmd.Printf(" - %s\n", key)
}
return nil
}

func getContext(s state.State, cmd *cobra.Command) (*config.Context, error) {

var ctx *config.Context

ctxName, _ := cmd.Flags().GetString("context")
if ctxName != "" {
ctx = s.Config().ContextByName(ctxName)
if ctx == nil {
return nil, fmt.Errorf("context not found: %v", ctxName)
}
} else {
ctx = s.Config().ActiveContext()
if ctx == nil {
return nil, fmt.Errorf("no active context")
}
}

return ctx, nil
}
7 changes: 7 additions & 0 deletions internal/cmd/server/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,13 @@ func createOptsFromFlags(
primaryIPv6IDorName, _ := flags.GetString("primary-ipv6")
protection, _ := flags.GetStringSlice("enable-protection")

// check if ssh-key flag was set, otherwise set to defaults
if !flags.Changed("ssh-key") {
if ctx := s.Config().ActiveContext(); ctx != nil {
sshKeys = ctx.SSHKeys
}
}

serverType, _, err := s.Client().ServerType().Get(s, serverTypeName)
if err != nil {
return
Expand Down
9 changes: 9 additions & 0 deletions internal/cmd/server/create_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ func TestCreate(t *testing.T) {

fx.ExpectEnsureToken()

fx.Config.EXPECT().
ActiveContext().
Return(nil)
fx.Client.ServerTypeClient.EXPECT().
Get(gomock.Any(), "cx11").
Return(&hcloud.ServerType{Architecture: hcloud.ArchitectureX86}, nil, nil)
Expand Down Expand Up @@ -134,6 +137,9 @@ func TestCreateJSON(t *testing.T) {
Status: hcloud.ServerStatusRunning,
}

fx.Config.EXPECT().
ActiveContext().
Return(nil)
fx.Client.ServerTypeClient.EXPECT().
Get(gomock.Any(), "cx11").
Return(&hcloud.ServerType{Architecture: hcloud.ArchitectureX86}, nil, nil)
Expand Down Expand Up @@ -176,6 +182,9 @@ func TestCreateProtectionBackup(t *testing.T) {

fx.ExpectEnsureToken()

fx.Config.EXPECT().
ActiveContext().
Return(nil)
fx.Client.ServerTypeClient.EXPECT().
Get(gomock.Any(), "cx11").
Return(&hcloud.ServerType{Architecture: hcloud.ArchitectureX86}, nil, nil)
Expand Down
8 changes: 8 additions & 0 deletions internal/cmd/server/enable_rescue.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ var EnableRescueCmd = base.Cmd{
opts.Type = hcloud.ServerRescueType(rescueType)

sshKeys, _ := cmd.Flags().GetStringSlice("ssh-key")

// check if ssh-key flag was set, otherwise set to defaults
if !cmd.Flags().Changed("ssh-key") {
if ctx := s.Config().ActiveContext(); ctx != nil {
sshKeys = ctx.SSHKeys
}
}

for _, sshKeyIDOrName := range sshKeys {
sshKey, _, err := s.Client().SSHKey().Get(s, sshKeyIDOrName)
if err != nil {
Expand Down
20 changes: 12 additions & 8 deletions internal/state/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ type Config interface {
}

type Context struct {
Name string
Token string
Name string
Token string
SSHKeys []string
}

type config struct {
Expand Down Expand Up @@ -81,8 +82,9 @@ type rawConfig struct {
}

type rawConfigContext struct {
Name string `toml:"name"`
Token string `toml:"token"`
Name string `toml:"name"`
Token string `toml:"token"`
SSHKeys []string `toml:"ssh_keys"`
}

func (cfg *config) Marshal() ([]byte, error) {
Expand All @@ -93,8 +95,9 @@ func (cfg *config) Marshal() ([]byte, error) {
}
for _, context := range cfg.contexts {
raw.contexts = append(raw.contexts, rawConfigContext{
Name: context.Name,
Token: context.Token,
Name: context.Name,
Token: context.Token,
SSHKeys: context.SSHKeys,
})
}
return toml.Marshal(raw)
Expand Down Expand Up @@ -160,8 +163,9 @@ func (cfg *config) unmarshal(data []byte) error {
}
for _, rawContext := range raw.contexts {
cfg.contexts = append(cfg.contexts, &Context{
Name: rawContext.Name,
Token: rawContext.Token,
Name: rawContext.Name,
Token: rawContext.Token,
SSHKeys: rawContext.SSHKeys,
})
}
if raw.activeContext != "" {
Expand Down

0 comments on commit 79ee38e

Please sign in to comment.