Skip to content

Commit

Permalink
fix: modify credential refresh to support stacked contexts (#856)
Browse files Browse the repository at this point in the history
Signed-off-by: Grant Linville <[email protected]>
  • Loading branch information
g-linville authored Sep 23, 2024
1 parent 2eafb08 commit ec0c019
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 12 deletions.
4 changes: 4 additions & 0 deletions pkg/credentials/noop.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ func (s NoopStore) Add(context.Context, Credential) error {
return nil
}

func (s NoopStore) Refresh(context.Context, Credential) error {
return nil
}

func (s NoopStore) Remove(context.Context, string) error {
return nil
}
Expand Down
21 changes: 21 additions & 0 deletions pkg/credentials/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"path/filepath"
"regexp"
"slices"
"strings"

"github.com/docker/cli/cli/config/credentials"
Expand All @@ -26,6 +27,7 @@ type CredentialBuilder interface {
type CredentialStore interface {
Get(ctx context.Context, toolName string) (*Credential, bool, error)
Add(ctx context.Context, cred Credential) error
Refresh(ctx context.Context, cred Credential) error
Remove(ctx context.Context, toolName string) error
List(ctx context.Context) ([]Credential, error)
}
Expand Down Expand Up @@ -95,6 +97,8 @@ func (s Store) Get(ctx context.Context, toolName string) (*Credential, bool, err
return &cred, true, nil
}

// Add adds a new credential to the credential store.
// Any context set on the credential object will be overwritten with the first context of the credential store.
func (s Store) Add(ctx context.Context, cred Credential) error {
first := first(s.credCtxs)
if first == AllCredentialContexts {
Expand All @@ -113,6 +117,23 @@ func (s Store) Add(ctx context.Context, cred Credential) error {
return store.Store(auth)
}

// Refresh updates an existing credential in the credential store.
func (s Store) Refresh(ctx context.Context, cred Credential) error {
if !slices.Contains(s.credCtxs, cred.Context) {
return fmt.Errorf("context %q not in list of valid contexts for this credential store", cred.Context)
}

store, err := s.getStore(ctx)
if err != nil {
return err
}
auth, err := cred.toDockerAuthConfig()
if err != nil {
return err
}
return store.Store(auth)
}

func (s Store) Remove(ctx context.Context, toolName string) error {
first := first(s.credCtxs)
if len(s.credCtxs) > 1 || first == AllCredentialContexts {
Expand Down
44 changes: 32 additions & 12 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -854,8 +854,10 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
}

var (
c *credentials.Credential
exists bool
c *credentials.Credential
resultCredential credentials.Credential
exists bool
refresh bool
)

rm := runtimeWithLogger(callCtx, monitor, r.runtimeManager)
Expand Down Expand Up @@ -886,6 +888,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
if !exists || c.IsExpired() {
// If the existing credential is expired, we need to provide it to the cred tool through the environment.
if exists && c.IsExpired() {
refresh = true
credJSON, err := json.Marshal(c)
if err != nil {
return nil, fmt.Errorf("failed to marshal credential: %w", err)
Expand Down Expand Up @@ -916,39 +919,56 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
continue
}

if err := json.Unmarshal([]byte(*res.Result), &c); err != nil {
if err := json.Unmarshal([]byte(*res.Result), &resultCredential); err != nil {
return nil, fmt.Errorf("failed to unmarshal credential tool %s response: %w", ref.Reference, err)
}
c.ToolName = credName
c.Type = credentials.CredentialTypeTool
resultCredential.ToolName = credName
resultCredential.Type = credentials.CredentialTypeTool

if refresh {
// If this is a credential refresh, we need to make sure we use the same context.
resultCredential.Context = c.Context
} else {
// If it is a new credential, let the credential store determine the context.
resultCredential.Context = ""
}

isEmpty := true
for _, v := range c.Env {
for _, v := range resultCredential.Env {
if v != "" {
isEmpty = false
break
}
}

if !c.Ephemeral {
if !resultCredential.Ephemeral {
// Only store the credential if the tool is on GitHub or has an alias, and the credential is non-empty.
if (isGitHubTool(toolName) && callCtx.Program.ToolSet[ref.ToolID].Source.Repo != nil) || credentialAlias != "" {
if isEmpty {
log.Warnf("Not saving empty credential for tool %s", toolName)
} else if err := r.credStore.Add(callCtx.Ctx, *c); err != nil {
return nil, fmt.Errorf("failed to add credential for tool %s: %w", toolName, err)
} else {
if refresh {
err = r.credStore.Refresh(callCtx.Ctx, resultCredential)
} else {
err = r.credStore.Add(callCtx.Ctx, resultCredential)
}
if err != nil {
return nil, fmt.Errorf("failed to save credential for tool %s: %w", toolName, err)
}
}
} else {
log.Warnf("Not saving credential for tool %s - credentials will only be saved for tools from GitHub, or tools that use aliases.", toolName)
}
}
} else {
resultCredential = *c
}

if c.ExpiresAt != nil && (nearestExpiration == nil || nearestExpiration.After(*c.ExpiresAt)) {
nearestExpiration = c.ExpiresAt
if resultCredential.ExpiresAt != nil && (nearestExpiration == nil || nearestExpiration.After(*resultCredential.ExpiresAt)) {
nearestExpiration = resultCredential.ExpiresAt
}

for k, v := range c.Env {
for k, v := range resultCredential.Env {
env = append(env, fmt.Sprintf("%s=%s", k, v))
}
}
Expand Down

0 comments on commit ec0c019

Please sign in to comment.