Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for filesystem implementation of token cache #5421

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions flytectl/cmd/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,28 @@
"fmt"
"strings"

"github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache"
"github.com/flyteorg/flyte/flytestdlib/config"

"github.com/flyteorg/flyte/flytectl/pkg/printer"
)

var (
defaultConfig = &Config{
Output: printer.OutputFormatTABLE.String(),
Output: printer.OutputFormatTABLE.String(),
TokenCacheType: cache.TokenCacheTypeKeyring,
}

section = config.MustRegisterSection("root", defaultConfig)
)

// Config hold configuration for flytectl flag
type Config struct {
Project string `json:"project" pflag:",Specifies the project to work on."`
Domain string `json:"domain" pflag:",Specifies the domain to work on."`
Output string `json:"output" pflag:",Specifies the output type."`
Interactive bool `json:"interactive" pflag:",Set this to trigger bubbletea interface."`
Project string `json:"project" pflag:",Specifies the project to work on."`
Domain string `json:"domain" pflag:",Specifies the domain to work on."`
Output string `json:"output" pflag:",Specifies the output type."`
Interactive bool `json:"interactive" pflag:",Set this to trigger bubbletea interface."`
TokenCacheType cache.TokenCacheType `json:"token_cache_type" pflag:",Specifices the token cache type to use for fetching / saving auth tokens."`

Check failure on line 28 in flytectl/cmd/config/config.go

View workflow job for this annotation

GitHub Actions / Check for spelling errors

Specifices ==> Specifies
}

// OutputFormat will return output format
Expand Down
20 changes: 16 additions & 4 deletions flytectl/cmd/core/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"github.com/flyteorg/flyte/flytectl/cmd/config"
"github.com/flyteorg/flyte/flytectl/pkg/pkce"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache"

"github.com/spf13/cobra"
"github.com/spf13/pflag"
Expand Down Expand Up @@ -70,13 +71,24 @@
return cmdEntry.CmdFunc(ctx, args, CommandContext{})
}

var tokenCache cache.TokenCache
svcUser := fmt.Sprintf("%s:%s", adminCfg.Endpoint.String(), pkce.KeyRingServiceUser)
switch config.GetConfig().TokenCacheType {
case cache.TokenCacheTypeFilesystem:
tokenCache = pkce.NewtokenCacheFilesystemProvider(svcUser)

Check warning on line 78 in flytectl/cmd/core/cmd.go

View check run for this annotation

Codecov / codecov/patch

flytectl/cmd/core/cmd.go#L77-L78

Added lines #L77 - L78 were not covered by tests
case cache.TokenCacheTypeKeyring:
fallthrough
default:
tokenCache = pkce.TokenCacheKeyringProvider{
ServiceUser: svcUser,
ServiceName: pkce.KeyRingServiceName,
}
}

cmdCtx := NewCommandContextNoClient(cmd.OutOrStdout())
if !cmdEntry.DisableFlyteClient {
clientSet, err := admin.ClientSetBuilder().WithConfig(admin.GetConfig(ctx)).
WithTokenCache(pkce.TokenCacheKeyringProvider{
ServiceUser: fmt.Sprintf("%s:%s", adminCfg.Endpoint.String(), pkce.KeyRingServiceUser),
ServiceName: pkce.KeyRingServiceName,
}).Build(ctx)
WithTokenCache(tokenCache).Build(ctx)
if err != nil {
return err
}
Expand Down
2 changes: 2 additions & 0 deletions flytectl/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/flyteorg/flyte/flytectl/cmd/version"
f "github.com/flyteorg/flyte/flytectl/pkg/filesystemutils"
"github.com/flyteorg/flyte/flytectl/pkg/printer"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache"
stdConfig "github.com/flyteorg/flyte/flytestdlib/config"
"github.com/flyteorg/flyte/flytestdlib/config/viper"

Expand Down Expand Up @@ -57,6 +58,7 @@ func newRootCmd() *cobra.Command {
rootCmd.PersistentFlags().StringVarP(&(config.GetConfig().Domain), "domain", "d", "", "Specifies the Flyte project's domain.")
rootCmd.PersistentFlags().StringVarP(&(config.GetConfig().Output), "output", "o", printer.OutputFormatTABLE.String(), fmt.Sprintf("Specifies the output type - supported formats %s. NOTE: dot, doturl are only supported for Workflow", printer.OutputFormats()))
rootCmd.PersistentFlags().BoolVarP(&(config.GetConfig().Interactive), "interactive", "i", false, "Set this flag to use an interactive CLI")
rootCmd.PersistentFlags().Var(&(config.GetConfig().TokenCacheType), "token-cache-type", fmt.Sprintf("Type of token cache to use (available options are %s)", cache.AllTokenCacheTypes))

rootCmd.AddCommand(get.CreateGetCommand())
compileCmd := compile.CreateCompileCommand()
Expand Down
141 changes: 141 additions & 0 deletions flytectl/pkg/pkce/token_cache_filesystem.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package pkce

import (
"errors"
"fmt"
"os"
"path/filepath"
"sync"

"golang.org/x/oauth2"

b64 "encoding/base64"
"encoding/json"

f "github.com/flyteorg/flyte/flytectl/pkg/filesystemutils"
)

// tokenCacheFilesystemProvider wraps the logic to save and retrieve tokens from the fs.
type tokenCacheFilesystemProvider struct {
ServiceUser string

// credentialsFile is the path to the file where the credentials are stored. This is
// typically $HOME/.flyte/credentials.json but embedded as a private field for tests.
credentialsFile string

mu sync.RWMutex
}

func NewtokenCacheFilesystemProvider(serviceUser string) *tokenCacheFilesystemProvider {
return &tokenCacheFilesystemProvider{
ServiceUser: serviceUser,
credentialsFile: f.FilePathJoin(f.UserHomeDir(), ".flyte", "credentials.json"),

Check warning on line 32 in flytectl/pkg/pkce/token_cache_filesystem.go

View check run for this annotation

Codecov / codecov/patch

flytectl/pkg/pkce/token_cache_filesystem.go#L29-L32

Added lines #L29 - L32 were not covered by tests
}
}

type credentials map[string]*oauth2.Token

func (c credentials) MarshalJSON() ([]byte, error) {
m := make(map[string]string)
for k, v := range c {
b, err := json.Marshal(v)
if err != nil {
return nil, err

Check warning on line 43 in flytectl/pkg/pkce/token_cache_filesystem.go

View check run for this annotation

Codecov / codecov/patch

flytectl/pkg/pkce/token_cache_filesystem.go#L43

Added line #L43 was not covered by tests
}
m[k] = b64.StdEncoding.EncodeToString(b)
}
return json.Marshal(m)
}

func (c credentials) UnmarshalJSON(b []byte) error {
m := make(map[string]string)
if err := json.Unmarshal(b, &m); err != nil {
return err

Check warning on line 53 in flytectl/pkg/pkce/token_cache_filesystem.go

View check run for this annotation

Codecov / codecov/patch

flytectl/pkg/pkce/token_cache_filesystem.go#L53

Added line #L53 was not covered by tests
}
for k, v := range m {
s, err := b64.StdEncoding.DecodeString(v)
if err != nil {
return err

Check warning on line 58 in flytectl/pkg/pkce/token_cache_filesystem.go

View check run for this annotation

Codecov / codecov/patch

flytectl/pkg/pkce/token_cache_filesystem.go#L58

Added line #L58 was not covered by tests
}
tk := &oauth2.Token{}
if err = json.Unmarshal(s, tk); err != nil {
return err

Check warning on line 62 in flytectl/pkg/pkce/token_cache_filesystem.go

View check run for this annotation

Codecov / codecov/patch

flytectl/pkg/pkce/token_cache_filesystem.go#L62

Added line #L62 was not covered by tests
}
c[k] = tk
}
return nil
}

func (t *tokenCacheFilesystemProvider) SaveToken(token *oauth2.Token) error {
t.mu.Lock()
defer t.mu.Unlock()

if token.AccessToken == "" {
return fmt.Errorf("cannot save empty token with expiration %v", token.Expiry)
}

dir := filepath.Dir(t.credentialsFile)
if err := os.MkdirAll(dir, 0700); err != nil {
return fmt.Errorf("creating base directory (%s) for credentials: %s", dir, err.Error())

Check warning on line 79 in flytectl/pkg/pkce/token_cache_filesystem.go

View check run for this annotation

Codecov / codecov/patch

flytectl/pkg/pkce/token_cache_filesystem.go#L79

Added line #L79 was not covered by tests
}

creds, err := t.getExistingCredentials()
if err != nil {
return err

Check warning on line 84 in flytectl/pkg/pkce/token_cache_filesystem.go

View check run for this annotation

Codecov / codecov/patch

flytectl/pkg/pkce/token_cache_filesystem.go#L84

Added line #L84 was not covered by tests
}
creds[t.ServiceUser] = token

tmp, err := os.CreateTemp("", "flytectl")
if err != nil {
return fmt.Errorf("creating tmp file for credentials update: %s", err.Error())

Check warning on line 90 in flytectl/pkg/pkce/token_cache_filesystem.go

View check run for this annotation

Codecov / codecov/patch

flytectl/pkg/pkce/token_cache_filesystem.go#L90

Added line #L90 was not covered by tests
}
defer os.Remove(tmp.Name())

b, err := json.Marshal(creds)
if err != nil {
return fmt.Errorf("marshalling credentials: %s", err.Error())

Check warning on line 96 in flytectl/pkg/pkce/token_cache_filesystem.go

View check run for this annotation

Codecov / codecov/patch

flytectl/pkg/pkce/token_cache_filesystem.go#L96

Added line #L96 was not covered by tests
}
if _, err := tmp.Write(b); err != nil {
return fmt.Errorf("writing updated credentials to tmp file: %s", err.Error())

Check warning on line 99 in flytectl/pkg/pkce/token_cache_filesystem.go

View check run for this annotation

Codecov / codecov/patch

flytectl/pkg/pkce/token_cache_filesystem.go#L99

Added line #L99 was not covered by tests
}

if err = os.Rename(tmp.Name(), t.credentialsFile); err != nil {
return fmt.Errorf("updating credentials via tmp file rename: %s", err.Error())

Check warning on line 103 in flytectl/pkg/pkce/token_cache_filesystem.go

View check run for this annotation

Codecov / codecov/patch

flytectl/pkg/pkce/token_cache_filesystem.go#L103

Added line #L103 was not covered by tests
}

return nil
}

func (t *tokenCacheFilesystemProvider) GetToken() (*oauth2.Token, error) {
t.mu.RLock()
defer t.mu.RUnlock()

creds, err := t.getExistingCredentials()
if err != nil {
return nil, err

Check warning on line 115 in flytectl/pkg/pkce/token_cache_filesystem.go

View check run for this annotation

Codecov / codecov/patch

flytectl/pkg/pkce/token_cache_filesystem.go#L115

Added line #L115 was not covered by tests
}

if token, ok := creds[t.ServiceUser]; ok {
return token, nil
}

return nil, errors.New("token does not exist")
}

func (t *tokenCacheFilesystemProvider) getExistingCredentials() (credentials, error) {
creds := credentials{}
if _, err := os.Stat(t.credentialsFile); errors.Is(err, os.ErrNotExist) {
return creds, nil
}

b, err := os.ReadFile(t.credentialsFile)
if err != nil {
return nil, fmt.Errorf("reading existing credentials: %s", err.Error())

Check warning on line 133 in flytectl/pkg/pkce/token_cache_filesystem.go

View check run for this annotation

Codecov / codecov/patch

flytectl/pkg/pkce/token_cache_filesystem.go#L133

Added line #L133 was not covered by tests
}

if err = json.Unmarshal(b, &creds); err != nil {
return nil, fmt.Errorf("unmarshalling credentials: %s", err.Error())

Check warning on line 137 in flytectl/pkg/pkce/token_cache_filesystem.go

View check run for this annotation

Codecov / codecov/patch

flytectl/pkg/pkce/token_cache_filesystem.go#L137

Added line #L137 was not covered by tests
}

return creds, nil
}
96 changes: 96 additions & 0 deletions flytectl/pkg/pkce/token_cache_filesystem_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package pkce

import (
"encoding/json"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
)

func TestSaveAndGetTokenFilesystem(t *testing.T) {
setup := func(t *testing.T) *tokenCacheFilesystemProvider {
t.Helper()
// Everything inside the directory is automatically cleaned up by the test runner.
dir := t.TempDir()
tokenCacheProvider := &tokenCacheFilesystemProvider{
ServiceUser: "testServiceUser",
credentialsFile: filepath.Join(dir, "credentials.json"),
}
return tokenCacheProvider
}

t.Run("Valid Save/Get Token", func(t *testing.T) {
tokenCacheProvider := setup(t)

plan, err := os.ReadFile("testdata/token.json")
require.NoError(t, err)

var tokenData oauth2.Token
err = json.Unmarshal(plan, &tokenData)
require.NoError(t, err)

err = tokenCacheProvider.SaveToken(&tokenData)
require.NoError(t, err)

var savedToken *oauth2.Token
savedToken, err = tokenCacheProvider.GetToken()
require.NoError(t, err)

assert.NotNil(t, savedToken)
assert.Equal(t, tokenData.AccessToken, savedToken.AccessToken)
assert.Equal(t, tokenData.TokenType, savedToken.TokenType)
assert.Equal(t, tokenData.Expiry, savedToken.Expiry)
})

t.Run("Empty access token Save", func(t *testing.T) {
tokenCacheProvider := setup(t)

plan, err := os.ReadFile("testdata/empty_access_token.json")
require.NoError(t, err)

var tokenData oauth2.Token
err = json.Unmarshal(plan, &tokenData)
require.NoError(t, err)

err = tokenCacheProvider.SaveToken(&tokenData)
assert.Error(t, err)
})

t.Run("Different service name", func(t *testing.T) {
tokenCacheProvider := setup(t)

plan, err := os.ReadFile("testdata/token.json")
require.NoError(t, err)

var tokenData oauth2.Token
err = json.Unmarshal(plan, &tokenData)
require.NoError(t, err)

err = tokenCacheProvider.SaveToken(&tokenData)
require.NoError(t, err)

tokenCacheProvider2 := setup(t)

var savedToken *oauth2.Token
savedToken, err = tokenCacheProvider2.GetToken()
assert.Error(t, err)
assert.Nil(t, savedToken)

err = tokenCacheProvider2.SaveToken(&tokenData)
require.NoError(t, err)

// new token exists
savedToken, err = tokenCacheProvider2.GetToken()
require.NoError(t, err)
assert.NotNil(t, savedToken)

// token for different service name still exists
savedToken, err = tokenCacheProvider.GetToken()
require.NoError(t, err)
assert.NotNil(t, savedToken)
})
}
45 changes: 44 additions & 1 deletion flyteidl/clients/go/admin/cache/token_cache.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,52 @@
package cache

import "golang.org/x/oauth2"
import (
"fmt"
"slices"
"strings"

"golang.org/x/oauth2"
)

//go:generate mockery -all -case=underscore

// TokenCacheType defines the type of token cache implementation.
type TokenCacheType string

const (
// TokenCacheTypeKeyring represents the token cache implementation using the OS's keyring.
TokenCacheTypeKeyring TokenCacheType = "keyring"
// TokenCacheTypeInMemory represents the token cache implementation using an in-memory cache.
TokenCacheTypeInMemory = "inmemory"
// TokenCacheTypeFilesystem represents the token cache implementation using the local filesystem.
TokenCacheTypeFilesystem = "filesystem"
)

var AllTokenCacheTypes = []TokenCacheType{TokenCacheTypeKeyring, TokenCacheTypeInMemory, TokenCacheTypeFilesystem}

// String implements pflag.Value interface.
func (t *TokenCacheType) String() string {
if t == nil {
return ""
}
return string(*t)
}

// Set implements pflag.Value interface.
func (t *TokenCacheType) Set(value string) error {
if slices.Contains(AllTokenCacheTypes, TokenCacheType(strings.ToLower(value))) {
*t = TokenCacheType(value)
return nil
}

return fmt.Errorf("%s is an unrecognized token cache type (supported types %v)", value, AllTokenCacheTypes)
}

// Type implements pflag.Value interface.
func (t *TokenCacheType) Type() string {
return "token-cache-type"
}

// TokenCache defines the interface needed to cache and retrieve oauth tokens.
type TokenCache interface {
// SaveToken saves the token securely to cache.
Expand Down
Loading