diff --git a/flytectl/cmd/config/config.go b/flytectl/cmd/config/config.go index 4296d28b3b8..1de38d02e74 100644 --- a/flytectl/cmd/config/config.go +++ b/flytectl/cmd/config/config.go @@ -4,6 +4,7 @@ import ( "fmt" "strings" + "github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache" "github.com/flyteorg/flyte/flytestdlib/config" "github.com/flyteorg/flyte/flytectl/pkg/printer" @@ -11,7 +12,8 @@ import ( var ( defaultConfig = &Config{ - Output: printer.OutputFormatTABLE.String(), + Output: printer.OutputFormatTABLE.String(), + TokenCacheType: cache.TokenCacheTypeKeyring, } section = config.MustRegisterSection("root", defaultConfig) @@ -19,10 +21,11 @@ var ( // Config hold configuration for flytectl flag type Config struct { - Project string `json:"project" pflag:",Specifies the project to work on."` - Domain string `json:"domain" pflag:",Specifies the domain to work on."` - Output string `json:"output" pflag:",Specifies the output type."` - Interactive bool `json:"interactive" pflag:",Set this to trigger bubbletea interface."` + Project string `json:"project" pflag:",Specifies the project to work on."` + Domain string `json:"domain" pflag:",Specifies the domain to work on."` + Output string `json:"output" pflag:",Specifies the output type."` + Interactive bool `json:"interactive" pflag:",Set this to trigger bubbletea interface."` + TokenCacheType cache.TokenCacheType `json:"token_cache_type" pflag:",Specifices the token cache type to use for fetching / saving auth tokens."` } // OutputFormat will return output format diff --git a/flytectl/cmd/core/cmd.go b/flytectl/cmd/core/cmd.go index 989f4b7ebbb..5fe817f62c4 100644 --- a/flytectl/cmd/core/cmd.go +++ b/flytectl/cmd/core/cmd.go @@ -11,6 +11,7 @@ import ( "github.com/flyteorg/flyte/flytectl/cmd/config" "github.com/flyteorg/flyte/flytectl/pkg/pkce" "github.com/flyteorg/flyte/flyteidl/clients/go/admin" + "github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache" "github.com/spf13/cobra" "github.com/spf13/pflag" @@ -70,13 +71,24 @@ func generateCommandFunc(cmdEntry CommandEntry) func(cmd *cobra.Command, args [] return cmdEntry.CmdFunc(ctx, args, CommandContext{}) } + var tokenCache cache.TokenCache + svcUser := fmt.Sprintf("%s:%s", adminCfg.Endpoint.String(), pkce.KeyRingServiceUser) + switch config.GetConfig().TokenCacheType { + case cache.TokenCacheTypeFilesystem: + tokenCache = pkce.NewtokenCacheFilesystemProvider(svcUser) + case cache.TokenCacheTypeKeyring: + fallthrough + default: + tokenCache = pkce.TokenCacheKeyringProvider{ + ServiceUser: svcUser, + ServiceName: pkce.KeyRingServiceName, + } + } + cmdCtx := NewCommandContextNoClient(cmd.OutOrStdout()) if !cmdEntry.DisableFlyteClient { clientSet, err := admin.ClientSetBuilder().WithConfig(admin.GetConfig(ctx)). - WithTokenCache(pkce.TokenCacheKeyringProvider{ - ServiceUser: fmt.Sprintf("%s:%s", adminCfg.Endpoint.String(), pkce.KeyRingServiceUser), - ServiceName: pkce.KeyRingServiceName, - }).Build(ctx) + WithTokenCache(tokenCache).Build(ctx) if err != nil { return err } diff --git a/flytectl/cmd/root.go b/flytectl/cmd/root.go index 112fa4074c5..c81e4efbd3a 100644 --- a/flytectl/cmd/root.go +++ b/flytectl/cmd/root.go @@ -20,6 +20,7 @@ import ( "github.com/flyteorg/flyte/flytectl/cmd/version" f "github.com/flyteorg/flyte/flytectl/pkg/filesystemutils" "github.com/flyteorg/flyte/flytectl/pkg/printer" + "github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache" stdConfig "github.com/flyteorg/flyte/flytestdlib/config" "github.com/flyteorg/flyte/flytestdlib/config/viper" @@ -57,6 +58,7 @@ func newRootCmd() *cobra.Command { rootCmd.PersistentFlags().StringVarP(&(config.GetConfig().Domain), "domain", "d", "", "Specifies the Flyte project's domain.") rootCmd.PersistentFlags().StringVarP(&(config.GetConfig().Output), "output", "o", printer.OutputFormatTABLE.String(), fmt.Sprintf("Specifies the output type - supported formats %s. NOTE: dot, doturl are only supported for Workflow", printer.OutputFormats())) rootCmd.PersistentFlags().BoolVarP(&(config.GetConfig().Interactive), "interactive", "i", false, "Set this flag to use an interactive CLI") + rootCmd.PersistentFlags().Var(&(config.GetConfig().TokenCacheType), "token-cache-type", fmt.Sprintf("Type of token cache to use (available options are %s)", cache.AllTokenCacheTypes)) rootCmd.AddCommand(get.CreateGetCommand()) compileCmd := compile.CreateCompileCommand() diff --git a/flytectl/pkg/pkce/token_cache_filesystem.go b/flytectl/pkg/pkce/token_cache_filesystem.go new file mode 100644 index 00000000000..89249a40308 --- /dev/null +++ b/flytectl/pkg/pkce/token_cache_filesystem.go @@ -0,0 +1,137 @@ +package pkce + +import ( + "errors" + "fmt" + "os" + "path/filepath" + + "golang.org/x/oauth2" + + b64 "encoding/base64" + "encoding/json" + + f "github.com/flyteorg/flyte/flytectl/pkg/filesystemutils" +) + +// tokenCacheFilesystemProvider wraps the logic to save and retrieve tokens from the fs. +type tokenCacheFilesystemProvider struct { + ServiceUser string + + // credentialsFile is the path to the file where the credentials are stored. This is + // typically $HOME/.flyte/credentials.json but embedded as a private field for tests. + credentialsFile string +} + +func NewtokenCacheFilesystemProvider(serviceUser string) *tokenCacheFilesystemProvider { + return &tokenCacheFilesystemProvider{ + ServiceUser: serviceUser, + credentialsFile: f.FilePathJoin(f.UserHomeDir(), ".flyte", "credentials.json"), + } +} + +type credentials map[string]*oauth2.Token + +func (c credentials) MarshalJSON() ([]byte, error) { + m := make(map[string]string) + for k, v := range c { + b, err := json.Marshal(v) + if err != nil { + return nil, err + } + m[k] = b64.StdEncoding.EncodeToString(b) + } + return json.Marshal(m) +} + +func (c credentials) UnmarshalJSON(b []byte) error { + m := make(map[string]string) + if err := json.Unmarshal(b, &m); err != nil { + return err + } + for k, v := range m { + s, err := b64.StdEncoding.DecodeString(v) + if err != nil { + return err + } + tk := &oauth2.Token{} + if err = json.Unmarshal(s, tk); err != nil { + return err + } + c[k] = tk + } + return nil +} + +func (t tokenCacheFilesystemProvider) SaveToken(token *oauth2.Token) error { + if token.AccessToken == "" { + return fmt.Errorf("cannot save empty token with expiration %v", token.Expiry) + } + + dir := filepath.Dir(t.credentialsFile) + if err := os.MkdirAll(dir, 0700); err != nil { + return fmt.Errorf("creating base directory (%s) for credentials: %s", dir, err.Error()) + } + + creds, err := t.getExistingCredentials() + if err != nil { + return err + } + creds[t.ServiceUser] = token + + tmp, err := os.CreateTemp("", "flytectl") + if err != nil { + return fmt.Errorf("creating tmp file for credentials update: %s", err.Error()) + } + defer os.Remove(tmp.Name()) + + b, err := json.Marshal(creds) + if err != nil { + return fmt.Errorf("marshalling credentials: %s", err.Error()) + } + if _, err := tmp.Write(b); err != nil { + return fmt.Errorf("writing updated credentials to tmp file: %s", err.Error()) + } + + if err = os.Rename(tmp.Name(), t.credentialsFile); err != nil { + return fmt.Errorf("updating credentials via tmp file rename: %s", err.Error()) + } + + return nil +} + +func (t tokenCacheFilesystemProvider) GetToken() (*oauth2.Token, error) { + creds, err := t.getExistingCredentials() + if err != nil { + return nil, err + } + + if token, ok := creds[t.ServiceUser]; ok { + return token, nil + } + + return nil, errors.New("token does not exist") +} + +func (t tokenCacheFilesystemProvider) getExistingCredentials() (credentials, error) { + dir := filepath.Dir(t.credentialsFile) + if err := os.MkdirAll(dir, 0700); err != nil { + return nil, fmt.Errorf("creating base directory (%s) for credentials: %s", dir, err.Error()) + } + + creds := credentials{} + if _, err := os.Stat(t.credentialsFile); errors.Is(err, os.ErrNotExist) { + return creds, nil + } + + b, err := os.ReadFile(t.credentialsFile) + if err != nil { + return nil, fmt.Errorf("reading existing credentials: %s", err.Error()) + } + + if err = json.Unmarshal(b, &creds); err != nil { + return nil, fmt.Errorf("unmarshalling credentials: %s", err.Error()) + } + + return creds, nil +} diff --git a/flytectl/pkg/pkce/token_cache_filesystem_test.go b/flytectl/pkg/pkce/token_cache_filesystem_test.go new file mode 100644 index 00000000000..2d78b0e1030 --- /dev/null +++ b/flytectl/pkg/pkce/token_cache_filesystem_test.go @@ -0,0 +1,96 @@ +package pkce + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +func TestSaveAndGetTokenFS(t *testing.T) { + setup := func(t *testing.T) tokenCacheFilesystemProvider { + t.Helper() + // Everything inside the directory is automatically cleaned up by the test runner. + dir := t.TempDir() + tokenCacheProvider := tokenCacheFilesystemProvider{ + ServiceUser: "testServiceUser", + credentialsFile: filepath.Join(dir, "credentials.json"), + } + return tokenCacheProvider + } + + t.Run("Valid Save/Get Token", func(t *testing.T) { + tokenCacheProvider := setup(t) + + plan, err := os.ReadFile("testdata/token.json") + require.NoError(t, err) + + var tokenData oauth2.Token + err = json.Unmarshal(plan, &tokenData) + require.NoError(t, err) + + err = tokenCacheProvider.SaveToken(&tokenData) + require.NoError(t, err) + + var savedToken *oauth2.Token + savedToken, err = tokenCacheProvider.GetToken() + require.NoError(t, err) + + assert.NotNil(t, savedToken) + assert.Equal(t, tokenData.AccessToken, savedToken.AccessToken) + assert.Equal(t, tokenData.TokenType, savedToken.TokenType) + assert.Equal(t, tokenData.Expiry, savedToken.Expiry) + }) + + t.Run("Empty access token Save", func(t *testing.T) { + tokenCacheProvider := setup(t) + + plan, err := os.ReadFile("testdata/empty_access_token.json") + require.NoError(t, err) + + var tokenData oauth2.Token + err = json.Unmarshal(plan, &tokenData) + require.NoError(t, err) + + err = tokenCacheProvider.SaveToken(&tokenData) + assert.Error(t, err) + }) + + t.Run("Different service name", func(t *testing.T) { + tokenCacheProvider := setup(t) + + plan, err := os.ReadFile("testdata/token.json") + require.NoError(t, err) + + var tokenData oauth2.Token + err = json.Unmarshal(plan, &tokenData) + require.NoError(t, err) + + err = tokenCacheProvider.SaveToken(&tokenData) + require.NoError(t, err) + + tokenCacheProvider2 := setup(t) + + var savedToken *oauth2.Token + savedToken, err = tokenCacheProvider2.GetToken() + assert.Error(t, err) + assert.Nil(t, savedToken) + + err = tokenCacheProvider2.SaveToken(&tokenData) + require.NoError(t, err) + + // new token exists + savedToken, err = tokenCacheProvider2.GetToken() + require.NoError(t, err) + assert.NotNil(t, savedToken) + + // token for different service name still exists + savedToken, err = tokenCacheProvider.GetToken() + require.NoError(t, err) + assert.NotNil(t, savedToken) + }) +} diff --git a/flyteidl/clients/go/admin/cache/token_cache.go b/flyteidl/clients/go/admin/cache/token_cache.go index e4e2b7e17f2..afc48409e9c 100644 --- a/flyteidl/clients/go/admin/cache/token_cache.go +++ b/flyteidl/clients/go/admin/cache/token_cache.go @@ -1,9 +1,52 @@ package cache -import "golang.org/x/oauth2" +import ( + "fmt" + "slices" + "strings" + + "golang.org/x/oauth2" +) //go:generate mockery -all -case=underscore +// TokenCacheType defines the type of token cache implementation. +type TokenCacheType string + +const ( + // TokenCacheTypeKeyring represents the token cache implementation using the OS's keyring. + TokenCacheTypeKeyring TokenCacheType = "keyring" + // TokenCacheTypeInMemory represents the token cache implementation using an in-memory cache. + TokenCacheTypeInMemory = "inmemory" + // TokenCacheTypeFilesystem represents the token cache implementation using the local filesystem. + TokenCacheTypeFilesystem = "filesystem" +) + +var AllTokenCacheTypes = []TokenCacheType{TokenCacheTypeKeyring, TokenCacheTypeInMemory, TokenCacheTypeFilesystem} + +// String implements pflag.Value interface. +func (t *TokenCacheType) String() string { + if t == nil { + return "" + } + return string(*t) +} + +// Set implements pflag.Value interface. +func (t *TokenCacheType) Set(value string) error { + if slices.Contains(AllTokenCacheTypes, TokenCacheType(strings.ToLower(value))) { + *t = TokenCacheType(value) + return nil + } + + return fmt.Errorf("%s is an unrecognized token cache type (supported types %v)", value, AllTokenCacheTypes) +} + +// Type implements pflag.Value interface. +func (t *TokenCacheType) Type() string { + return "token-cache-type" +} + // TokenCache defines the interface needed to cache and retrieve oauth tokens. type TokenCache interface { // SaveToken saves the token securely to cache.