diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index ad58de3d..5999a874 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -16,7 +16,7 @@ jobs: contents: write steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 - name: Set up Go diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 309fe47f..7cb7875a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -9,7 +9,7 @@ jobs: name: lint runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Go uses: actions/setup-go@v4 with: diff --git a/.github/workflows/website.yaml b/.github/workflows/website.yaml index 6cf946b9..a1580e63 100644 --- a/.github/workflows/website.yaml +++ b/.github/workflows/website.yaml @@ -19,7 +19,7 @@ jobs: run: working-directory: website steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Node uses: actions/setup-node@v3 diff --git a/.gitignore b/.gitignore index 6e3d0e2f..2034cfe8 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,6 @@ kubeconfig /examples/rrsa/kaniko-in-ack/deploy.yaml /ci/ossutil/ossutil /cputil +.terraform/ +.terraform.* +terraform.tfstate* diff --git a/Makefile b/Makefile index dfd1000d..6b14178a 100644 --- a/Makefile +++ b/Makefile @@ -17,7 +17,8 @@ build: .PHONY: test test: - go test -v ./... + go test -race -v ./... + cd pkg/credentials/provider && go test -race -v ./... .PHONY: e2e e2e: diff --git a/examples/rrsa/aliyuncli-demo/README.md b/examples/rrsa/aliyuncli-demo/README.md index e307250e..87129e1f 100644 --- a/examples/rrsa/aliyuncli-demo/README.md +++ b/examples/rrsa/aliyuncli-demo/README.md @@ -1,6 +1,24 @@ # aliyun cli demo -## Usage +config.json: + +``` +{ + "current": "default", + "profiles": [ + { + "name": "default", + "mode": "External", + "region_id": "cn-hangzhou", + "process_command": "ack-ram-tool export-credentials --ignore-aliyun-cli-credentials --log-level=ERROR", + "credentials_uri": "" + } + ], + "meta_path": "" +} +``` + +## Demo 1. Enable RRSA: diff --git a/examples/rrsa/aliyunlogcli-demo/README.md b/examples/rrsa/aliyunlogcli-demo/README.md index c85bcae9..311aea05 100644 --- a/examples/rrsa/aliyunlogcli-demo/README.md +++ b/examples/rrsa/aliyunlogcli-demo/README.md @@ -1,6 +1,11 @@ # aliyunlog cli demo -## Usage +``` +ack-ram-tool export-credentials --format=environment-variables -- \ + aliyunlog log list_project --region-endpoint=cn-hangzhou.log.aliyuncs.com +``` + +## Demo 1. Enable RRSA: diff --git a/examples/rrsa/cpp-demo/cpp-sdk/README.md b/examples/rrsa/cpp-demo/cpp-sdk/README.md index 97ef627f..5dbdd498 100644 --- a/examples/rrsa/cpp-demo/cpp-sdk/README.md +++ b/examples/rrsa/cpp-demo/cpp-sdk/README.md @@ -1,6 +1,6 @@ # cpp-sdk -## Usage +## Demo 1. Enable RRSA: diff --git a/examples/rrsa/go-sdk/README.md b/examples/rrsa/go-sdk/README.md index ef44d7e4..4d9a8586 100644 --- a/examples/rrsa/go-sdk/README.md +++ b/examples/rrsa/go-sdk/README.md @@ -2,7 +2,14 @@ Using [Alibaba Could Go SDK](https://github.com/aliyun/alibabacloud-go-sdk) with RRSA Auth. -## Usage +``` +go get github.com/aliyun/credentials-go@v1.2.7 +``` + +https://github.com/aliyun/credentials-go + + +## Demo 1. Enable RRSA: diff --git a/examples/rrsa/java-sdk/README.md b/examples/rrsa/java-sdk/README.md index 7dc54b5d..2574a004 100644 --- a/examples/rrsa/java-sdk/README.md +++ b/examples/rrsa/java-sdk/README.md @@ -2,8 +2,18 @@ Using [Alibaba Could Java SDK](https://github.com/aliyun/alibabacloud-java-sdk) with RRSA Auth. +``` + + com.aliyun + credentials-java + 0.2.12 + +``` + +https://github.com/aliyun/credentials-java + -## Usage +## Demo 1. Enable RRSA: diff --git a/examples/rrsa/kaniko-in-ack/README.md b/examples/rrsa/kaniko-in-ack/README.md index c2087109..7b6ec34d 100644 --- a/examples/rrsa/kaniko-in-ack/README.md +++ b/examples/rrsa/kaniko-in-ack/README.md @@ -5,7 +5,7 @@ Running kaniko in ACK: * build image with kaniko * push image to the ACR with RRSA Auth -## Usage +## Demo 1. Enable RRSA: diff --git a/examples/rrsa/log-go-sdk/README.md b/examples/rrsa/log-go-sdk/README.md index 0afccb83..642e47d2 100644 --- a/examples/rrsa/log-go-sdk/README.md +++ b/examples/rrsa/log-go-sdk/README.md @@ -2,7 +2,8 @@ Using [aliyun-log-go-sdk](https://github.com/aliyun/aliyun-log-go-sdk) with RRSA Auth. -## Usage + +## Demo 1. Enable RRSA: diff --git a/examples/rrsa/log-go-sdk/go.mod b/examples/rrsa/log-go-sdk/go.mod index d6157ad5..7871d22f 100644 --- a/examples/rrsa/log-go-sdk/go.mod +++ b/examples/rrsa/log-go-sdk/go.mod @@ -3,7 +3,7 @@ module github.com/AliyunContainerService/ack-ram-tool/examples/rrsa/log-go-sdk go 1.16 require ( - github.com/AliyunContainerService/ack-ram-tool/pkg/credentials/provider v0.7.1 + github.com/AliyunContainerService/ack-ram-tool/pkg/credentials/provider v0.9.0 github.com/aliyun/aliyun-log-go-sdk v0.1.54 github.com/stretchr/testify v1.5.1 // indirect golang.org/x/net v0.7.0 // indirect diff --git a/examples/rrsa/log-go-sdk/go.sum b/examples/rrsa/log-go-sdk/go.sum index 4952b96a..2c3462ac 100644 --- a/examples/rrsa/log-go-sdk/go.sum +++ b/examples/rrsa/log-go-sdk/go.sum @@ -1,7 +1,7 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -github.com/AliyunContainerService/ack-ram-tool/pkg/credentials/provider v0.7.1 h1:PXpSLU9ghgbUvDgRSr2N+SPHV5Ze0dYoqwGM4LSyfc4= -github.com/AliyunContainerService/ack-ram-tool/pkg/credentials/provider v0.7.1/go.mod h1:ULtI7L9xkNeJ07YNqSeT5EhjQAl1CpTgPcUn4KoNcuc= +github.com/AliyunContainerService/ack-ram-tool/pkg/credentials/provider v0.9.0 h1:PqNYfVOnnbTN9d2X8Hg8JCcp7H53YwMWFA6//AYDAg0= +github.com/AliyunContainerService/ack-ram-tool/pkg/credentials/provider v0.9.0/go.mod h1:ULtI7L9xkNeJ07YNqSeT5EhjQAl1CpTgPcUn4KoNcuc= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= diff --git a/examples/rrsa/log-java-sdk/README.md b/examples/rrsa/log-java-sdk/README.md index 26e12a2b..18255807 100644 --- a/examples/rrsa/log-java-sdk/README.md +++ b/examples/rrsa/log-java-sdk/README.md @@ -2,7 +2,18 @@ Using [aliyun-log-java-sdk](https://github.com/aliyun/aliyun-log-java-sdk) with RRSA Auth. -## Usage +``` + + com.aliyun + credentials-java + 0.2.12 + +``` + +https://github.com/aliyun/credentials-java + + +## Demo 1. Enable RRSA: diff --git a/examples/rrsa/nodejs-sdk/README.md b/examples/rrsa/nodejs-sdk/README.md index acbe774b..033e90f2 100644 --- a/examples/rrsa/nodejs-sdk/README.md +++ b/examples/rrsa/nodejs-sdk/README.md @@ -2,7 +2,14 @@ Using [Alibaba Could Node.js/TypeScript SDK](https://github.com/aliyun/alibabacloud-typescript-sdk) with RRSA Auth. -## Usage +``` +npm install @alicloud/credentials +``` + +https://github.com/aliyun/credentials-nodejs + + +## Demo 1. Enable RRSA: diff --git a/examples/rrsa/oss-go-sdk/README.md b/examples/rrsa/oss-go-sdk/README.md index 7c853b33..b2b306dd 100644 --- a/examples/rrsa/oss-go-sdk/README.md +++ b/examples/rrsa/oss-go-sdk/README.md @@ -2,7 +2,14 @@ Using [aliyun-oss-go-sdk](https://github.com/aliyun/aliyun-oss-go-sdk) with RRSA Auth. -## Usage +``` +go get github.com/aliyun/credentials-go@v1.2.7 +``` + +https://github.com/aliyun/credentials-go + + +## Demo 1. Enable RRSA: diff --git a/examples/rrsa/oss-java-sdk/README.md b/examples/rrsa/oss-java-sdk/README.md index 0fbb41b4..8f3ca1b3 100644 --- a/examples/rrsa/oss-java-sdk/README.md +++ b/examples/rrsa/oss-java-sdk/README.md @@ -2,7 +2,18 @@ Using [aliyun-oss-java-sdk](https://github.com/aliyun/aliyun-oss-java-sdk) with RRSA Auth. -## Usage +``` + + com.aliyun + credentials-java + 0.2.12 + +``` + +https://github.com/aliyun/credentials-java + + +## Demo 1. Enable RRSA: diff --git a/examples/rrsa/python3-sdk/README.md b/examples/rrsa/python3-sdk/README.md index 1f46c2d4..808a471f 100644 --- a/examples/rrsa/python3-sdk/README.md +++ b/examples/rrsa/python3-sdk/README.md @@ -2,7 +2,14 @@ Using [Alibaba Could Python 3 SDK](https://github.com/aliyun/alibabacloud-python-sdk) with RRSA Auth. -## Usage +``` +pip install alibabacloud_credentials>=0.3.1 +``` + +https://github.com/aliyun/credentials-python + + +## Demo 1. Enable RRSA: diff --git a/examples/rrsa/terraform-demo/rrsa-config/README.md b/examples/rrsa/terraform-demo/rrsa-config/README.md new file mode 100644 index 00000000..4fa0f7ce --- /dev/null +++ b/examples/rrsa/terraform-demo/rrsa-config/README.md @@ -0,0 +1,22 @@ +# RRSA Configuration Via Terraform + + +``` +aliyun/terraform-provider-alicloud > v1.171.0 +``` + +https://registry.terraform.io/providers/aliyun/alicloud/latest + + + +## Demo + +``` +export ALICLOUD_ACCESS_KEY= +export ALICLOUD_SECRET_KEY= +export ALICLOUD_REGION="cn-hangzhou" + +terraform init +terraform plan +terraform apply +``` diff --git a/examples/rrsa/terraform-demo/rrsa-config/main.tf b/examples/rrsa/terraform-demo/rrsa-config/main.tf new file mode 100644 index 00000000..03525b43 --- /dev/null +++ b/examples/rrsa/terraform-demo/rrsa-config/main.tf @@ -0,0 +1,121 @@ +provider "alicloud" { +} + +variable "k8s_name_prefix" { + description = "The name prefix used to create ASK cluster." + default = "ask-rrsa-example" +} + +resource "random_uuid" "this" {} + + +locals { + k8s_name_ask = substr(join("-", [var.k8s_name_prefix,"ask"]), 0, 63) + new_vpc_name = "tf-vpc-172-16" + new_vsw_name = "tf-vswitch-172-16-0" +} + +data "alicloud_zones" "default" { + available_resource_creation = "VSwitch" +} + +resource "alicloud_vpc" "vpc" { + vpc_name = local.new_vpc_name + cidr_block = "172.16.0.0/12" +} + +resource "alicloud_vswitch" "vsw" { + vswitch_name = local.new_vsw_name + vpc_id = alicloud_vpc.vpc.id + cidr_block = cidrsubnet(alicloud_vpc.vpc.cidr_block, 8, 8) + zone_id = data.alicloud_zones.default.zones[0].id +} + + +resource "alicloud_cs_serverless_kubernetes" "serverless" { + name = local.k8s_name_ask + version = "1.26.3-aliyun.1" + cluster_spec = "ack.pro.small" + vpc_id = alicloud_vpc.vpc.id + vswitch_ids = split(",", join(",", alicloud_vswitch.vsw.*.id)) + new_nat_gateway = false + endpoint_public_access_enabled = false + deletion_protection = false + load_balancer_spec = "slb.s2.small" + time_zone = "Asia/Shanghai" + service_cidr = "10.13.0.0/16" + service_discovery_types = ["CoreDNS"] + + # Enable RRSA + enable_rrsa = true +} + + +# k8s service account info +variable "k8s_namespace" { + default = "test-rrsa-ns" +} +variable "k8s_service_account" { + default = "foo-bar-manager-sa" +} + +# Create a new RAM Role. +resource "alicloud_ram_role" "role" { + name = "rrsa-demo-${alicloud_cs_serverless_kubernetes.serverless.id}" + document = <= c.runtimeSwitchCacheDuration { + c.logger().Debug(fmt.Sprintf("%s trigger select provider again", c.logPrefix)) + return nil + } + + return p +} + +func (c *ChainProvider) setCurrentProvider(p CredentialsProvider) { + c.lock.Lock() + defer c.lock.Unlock() + + prePT := fmt.Sprintf("%T", c.currentProvider) + pT := fmt.Sprintf("%T", p) + if prePT != pT { + c.logger().Info(fmt.Sprintf("%s switch to using new provider: %s -> %s", c.logPrefix, prePT, pT)) + } + + c.lastSelectProviderTime = time.Now() + c.currentProvider = p } func (c *ChainProvider) logger() Logger { @@ -59,31 +148,45 @@ func (c *ChainProvider) logger() Logger { } type DefaultChainProviderOptions struct { + EnableRuntimeSwitch bool + RuntimeSwitchCacheDuration time.Duration + STSEndpoint string ExpiryWindow time.Duration RefreshPeriod time.Duration Logger Logger + + logPrefix string } func NewDefaultChainProvider(opts DefaultChainProviderOptions) *ChainProvider { - p := NewChainProvider( - NewEnvProvider(EnvProviderOptions{}), - NewOIDCProvider(OIDCProviderOptions{ - STSEndpoint: opts.STSEndpoint, - ExpiryWindow: opts.ExpiryWindow, - RefreshPeriod: opts.RefreshPeriod, - Logger: opts.Logger, - }), - NewEncryptedFileProvider(EncryptedFileProviderOptions{ - ExpiryWindow: opts.ExpiryWindow, - RefreshPeriod: opts.RefreshPeriod, - Logger: opts.Logger, - }), - NewECSMetadataProvider(ECSMetadataProviderOptions{ - ExpiryWindow: opts.ExpiryWindow, - RefreshPeriod: opts.RefreshPeriod, - Logger: opts.Logger, - }), + opts.applyDefaults() + + p := NewChainProviderWithOptions( + []CredentialsProvider{ + NewEnvProvider(EnvProviderOptions{}), + NewOIDCProvider(OIDCProviderOptions{ + STSEndpoint: opts.STSEndpoint, + ExpiryWindow: opts.ExpiryWindow, + RefreshPeriod: opts.RefreshPeriod, + Logger: opts.Logger, + }), + NewEncryptedFileProvider(EncryptedFileProviderOptions{ + ExpiryWindow: opts.ExpiryWindow, + RefreshPeriod: opts.RefreshPeriod, + Logger: opts.Logger, + }), + NewECSMetadataProvider(ECSMetadataProviderOptions{ + ExpiryWindow: opts.ExpiryWindow, + RefreshPeriod: opts.RefreshPeriod, + Logger: opts.Logger, + }), + }, + ChainProviderOptions{ + EnableRuntimeSwitch: opts.EnableRuntimeSwitch, + RuntimeSwitchCacheDuration: opts.RuntimeSwitchCacheDuration, + logPrefix: opts.logPrefix, + }, ) p.Logger = opts.Logger return p @@ -100,3 +203,18 @@ func DefaultChainProviderWithLogger(l Logger) *ChainProvider { Logger: l, }) } + +func (o *ChainProviderOptions) applyDefaults() { + if o.RuntimeSwitchCacheDuration <= 0 { + o.RuntimeSwitchCacheDuration = defaultRuntimeSwitchCacheDuration + } + if o.logPrefix == "" { + o.logPrefix = "[ChainProvider]" + } +} + +func (o *DefaultChainProviderOptions) applyDefaults() { + if o.logPrefix == "" { + o.logPrefix = "[DefaultChainProvider]" + } +} diff --git a/pkg/credentials/provider/chain_provider_test.go b/pkg/credentials/provider/chain_provider_test.go new file mode 100644 index 00000000..8ef13009 --- /dev/null +++ b/pkg/credentials/provider/chain_provider_test.go @@ -0,0 +1,88 @@ +package provider + +import ( + "context" + "fmt" + "net/http" + "sync/atomic" + "testing" + "time" +) + +func TestChainProvider_Credentials_success(t *testing.T) { + p1 := NewAccessKeyProvider("", "") + p2 := NewAccessKeyProvider("", "") + p3 := NewSTSTokenProvider("ak3", "sk3", "sts3") + cp := NewChainProvider(p1, p2, p3) + + cred, err := cp.Credentials(context.TODO()) + if err != nil { + t.Errorf("should no error: %+v", err) + } + if cred.AccessKeyId != "ak3" || + cred.AccessKeySecret != "sk3" || + cred.SecurityToken != "sts3" { + t.Errorf("unexpect ret: %+v", *cred) + } +} + +func TestChainProvider_Credentials_no_provider(t *testing.T) { + cp := NewChainProvider(NewAccessKeyProvider("", "")) + cred, err := cp.Credentials(context.TODO()) + if err == nil { + t.Errorf("should return error: %+v", err) + } + t.Log(err) + if cred != nil { + t.Errorf("should return nil: %+v", *cred) + } +} + +func TestChainProvider_Stop(t *testing.T) { + var callCount int32 + s := setupHttpTestServer(t, func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&callCount, 1) + fmt.Fprint(w, ` +{} +`) + }) + + p := NewRoleArnProvider( + NewAccessKeyProvider("ak1", "sk1"), + "role_arn", + RoleArnProviderOptions{ + STSEndpoint: s.URL, + RefreshPeriod: time.Millisecond * 100, + Logger: TLogger{t: t}, + }, + ) + + cp := NewChainProvider(p) + cp.Logger = TLogger{t: t} + cp.Credentials(context.TODO()) + + cv := atomic.LoadInt32(&callCount) + if cv < 1 { + t.Errorf("callCount should >= 1: %v", cv) + } + + time.Sleep(time.Second) + cv = atomic.LoadInt32(&callCount) + if cv <= 1 { + t.Errorf("callCount should > 1: %v", cv) + } + + cp.Stop(context.TODO()) + time.Sleep(time.Second) + curr := atomic.LoadInt32(&callCount) + time.Sleep(time.Second) + + cv = atomic.LoadInt32(&callCount) + if cv != curr { + t.Errorf("callCount should == %v: %v", curr, cv) + } + + cp.Stop(context.TODO()) + cp.Stop(context.TODO()) + cp.Stop(context.TODO()) +} diff --git a/pkg/credentials/provider/ecsmetadata_provider.go b/pkg/credentials/provider/ecsmetadata_provider.go index 33263826..3acef2e1 100644 --- a/pkg/credentials/provider/ecsmetadata_provider.go +++ b/pkg/credentials/provider/ecsmetadata_provider.go @@ -72,6 +72,10 @@ func (e *ECSMetadataProvider) Credentials(ctx context.Context) (*Credentials, er return e.u.Credentials(ctx) } +func (e *ECSMetadataProvider) Stop(ctx context.Context) { + e.u.Stop(ctx) +} + type ecsMetadataStsResponse struct { AccessKeyId string `json:"AccessKeyId"` AccessKeySecret string `json:"AccessKeySecret"` @@ -136,6 +140,7 @@ func (e *ECSMetadataProvider) getMedataToken(ctx context.Context) (string, error return e.metadataToken, nil } + e.logger().Debug("start to get metadata token") h := http.Header{} h.Set("X-aliyun-ecs-metadata-token-ttl-seconds", fmt.Sprintf("%d", e.metadataTokenTTLSeconds)) body, err := e.getMedataData(ctx, http.MethodPut, "/latest/api/token", h) diff --git a/pkg/credentials/provider/ecsmetadata_provider_test.go b/pkg/credentials/provider/ecsmetadata_provider_test.go new file mode 100644 index 00000000..f0dec92b --- /dev/null +++ b/pkg/credentials/provider/ecsmetadata_provider_test.go @@ -0,0 +1,49 @@ +package provider + +import ( + "context" + "fmt" + "net/http" + "testing" +) + +func TestECSMetadataProvider_Credentials_success(t *testing.T) { + s := setupHttpTestServer(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/latest/api/token": + { + fmt.Fprint(w, "token-xxx") + return + } + default: + { + fmt.Fprint(w, ` +{ + "AccessKeyId": "ak", + "AccessKeySecret": "sk", + "SecurityToken": "tt", + "Expiration": "2206-01-02T15:04:05Z" +} +`) + } + } + }) + + p := NewECSMetadataProvider(ECSMetadataProviderOptions{ + Endpoint: s.URL, + RoleName: "test", + }) + + cred, err := p.Credentials(context.TODO()) + if err != nil { + t.Log(err) + t.Errorf("should no error: %+v", err) + } + + if cred.AccessKeyId != "ak" || + cred.AccessKeySecret != "sk" || + cred.SecurityToken != "tt" || + cred.Expiration.IsZero() { + t.Errorf("got unexpected cred") + } +} diff --git a/pkg/credentials/provider/env_provider.go b/pkg/credentials/provider/env_provider.go index 58151b66..e05fa768 100644 --- a/pkg/credentials/provider/env_provider.go +++ b/pkg/credentials/provider/env_provider.go @@ -13,41 +13,58 @@ const ( ) type EnvProvider struct { - cred *Credentials - - envAccessKeyId string - envAccessKeySecret string - envSecurityToken string + cp *ChainProvider } type EnvProviderOptions struct { EnvAccessKeyId string EnvAccessKeySecret string EnvSecurityToken string + + EnvRoleArn string + EnvOIDCProviderArn string + EnvOIDCTokenFile string } func NewEnvProvider(opts EnvProviderOptions) *EnvProvider { opts.applyDefaults() - return &EnvProvider{ - cred: &Credentials{ - AccessKeyId: os.Getenv(opts.EnvAccessKeyId), - AccessKeySecret: os.Getenv(opts.EnvAccessKeySecret), - SecurityToken: os.Getenv(opts.EnvSecurityToken), - }, - envAccessKeyId: opts.EnvAccessKeyId, - envAccessKeySecret: opts.EnvAccessKeySecret, - envSecurityToken: opts.EnvSecurityToken, - } + e := &EnvProvider{} + e.cp = e.getProvider(opts) + + return e } func (e *EnvProvider) Credentials(ctx context.Context) (*Credentials, error) { - if e.cred.AccessKeyId == "" || e.cred.AccessKeySecret == "" { - return nil, NewNotEnableError(fmt.Errorf("env %s or %s is empty", - e.envAccessKeyId, e.envAccessKeySecret)) + cred, err := e.cp.Credentials(ctx) + + if err != nil { + if IsNoAvailableProviderError(err) { + return nil, NewNotEnableError(fmt.Errorf("not found credentials from env: %w", err)) + } + return nil, err } - return e.cred.DeepCopy(), nil + return cred.DeepCopy(), nil +} + +func (e *EnvProvider) getProvider(opts EnvProviderOptions) *ChainProvider { + p1 := NewSTSTokenProvider( + os.Getenv(opts.EnvAccessKeyId), + os.Getenv(opts.EnvAccessKeySecret), + os.Getenv(opts.EnvSecurityToken), + ) + p2 := NewOIDCProvider(OIDCProviderOptions{ + RoleArn: os.Getenv(opts.EnvRoleArn), + OIDCProviderArn: os.Getenv(opts.EnvOIDCProviderArn), + OIDCTokenFile: os.Getenv(opts.EnvOIDCTokenFile), + }) + p3 := NewAccessKeyProvider( + os.Getenv(opts.EnvAccessKeyId), + os.Getenv(opts.EnvAccessKeySecret), + ) + cp := NewChainProvider(p1, p2, p3) + return cp } func (o *EnvProviderOptions) applyDefaults() { @@ -60,4 +77,14 @@ func (o *EnvProviderOptions) applyDefaults() { if o.EnvSecurityToken == "" { o.EnvSecurityToken = envSecurityToken } + + if o.EnvRoleArn == "" { + o.EnvRoleArn = defaultEnvRoleArn + } + if o.EnvOIDCProviderArn == "" { + o.EnvOIDCProviderArn = defaultEnvOIDCProviderArn + } + if o.EnvOIDCTokenFile == "" { + o.EnvOIDCTokenFile = defaultEnvOIDCTokenFile + } } diff --git a/pkg/credentials/provider/env_provider_test.go b/pkg/credentials/provider/env_provider_test.go new file mode 100644 index 00000000..697267c3 --- /dev/null +++ b/pkg/credentials/provider/env_provider_test.go @@ -0,0 +1,80 @@ +package provider + +import ( + "context" + "os" + "testing" +) + +func TestEnvProvider_Credentials(t *testing.T) { + envAk := "TestEnvProvider_Credentials_AK" + envSK := "TestEnvProvider_Credentials_SK" + envToken := "TestEnvProvider_Credentials_Token" + envRoleArn := "TestEnvProvider_Credentials_Role_ARN" + envOidcP := "TestEnvProvider_Credentials_OIDC_Pro" + envOidcT := "TestEnvProvider_Credentials_OIDC_Token" + + t.Run("no env", func(t *testing.T) { + p := NewEnvProvider(EnvProviderOptions{ + EnvAccessKeyId: envAk, + EnvAccessKeySecret: envSK, + EnvSecurityToken: envToken, + EnvRoleArn: envRoleArn, + EnvOIDCProviderArn: envOidcP, + EnvOIDCTokenFile: envOidcT, + }) + cred, err := p.Credentials(context.TODO()) + if err == nil { + t.Errorf("should return error: %+v", err) + } + t.Log(err) + if cred != nil { + t.Errorf("got unexpected cred: %+v", *cred) + } + }) + + t.Run("only ak env", func(t *testing.T) { + os.Setenv(envAk, "ak") + os.Setenv(envSK, "sk") + p := NewEnvProvider(EnvProviderOptions{ + EnvAccessKeyId: envAk, + EnvAccessKeySecret: envSK, + EnvSecurityToken: envToken, + EnvRoleArn: envRoleArn, + EnvOIDCProviderArn: envOidcP, + EnvOIDCTokenFile: envOidcT, + }) + cred, err := p.Credentials(context.TODO()) + if err != nil { + t.Errorf("should no error: %+v", err) + } + if cred.AccessKeyId != "ak" || + cred.AccessKeySecret != "sk" || + cred.SecurityToken != "" { + t.Errorf("got unexpected cred: %+v", *cred) + } + }) + + t.Run("sts token env", func(t *testing.T) { + os.Setenv(envAk, "ak") + os.Setenv(envSK, "sk") + os.Setenv(envToken, "sts-token") + p := NewEnvProvider(EnvProviderOptions{ + EnvAccessKeyId: envAk, + EnvAccessKeySecret: envSK, + EnvSecurityToken: envToken, + EnvRoleArn: envRoleArn, + EnvOIDCProviderArn: envOidcP, + EnvOIDCTokenFile: envOidcT, + }) + cred, err := p.Credentials(context.TODO()) + if err != nil { + t.Errorf("should no error: %+v", err) + } + if cred.AccessKeyId != "ak" || + cred.AccessKeySecret != "sk" || + cred.SecurityToken != "sts-token" { + t.Errorf("got unexpected cred: %+v", *cred) + } + }) +} diff --git a/pkg/credentials/provider/error.go b/pkg/credentials/provider/error.go new file mode 100644 index 00000000..5343c255 --- /dev/null +++ b/pkg/credentials/provider/error.go @@ -0,0 +1,45 @@ +package provider + +type NotEnableError struct { + err error +} + +type NoAvailableProviderError struct { + err error +} + +func NewNotEnableError(err error) *NotEnableError { + return &NotEnableError{err: err} +} + +func NewNoAvailableProviderError(err error) *NoAvailableProviderError { + return &NoAvailableProviderError{err: err} +} + +func (e NotEnableError) Error() string { + return e.err.Error() +} + +func (e NoAvailableProviderError) Error() string { + return e.err.Error() +} + +func IsNotEnableError(err error) bool { + if _, ok := err.(*NotEnableError); ok { + return true + } + if _, ok := err.(NotEnableError); ok { + return true + } + return false +} + +func IsNoAvailableProviderError(err error) bool { + if _, ok := err.(*NoAvailableProviderError); ok { + return true + } + if _, ok := err.(NoAvailableProviderError); ok { + return true + } + return false +} diff --git a/pkg/credentials/provider/file_provider.go b/pkg/credentials/provider/file_provider.go index 231b6d3a..5e8eebde 100644 --- a/pkg/credentials/provider/file_provider.go +++ b/pkg/credentials/provider/file_provider.go @@ -43,6 +43,10 @@ func (f *FileProvider) Credentials(ctx context.Context) (*Credentials, error) { return f.u.Credentials(ctx) } +func (f *FileProvider) Stop(ctx context.Context) { + f.u.Stop(ctx) +} + func (f *FileProvider) getCredentials(ctx context.Context) (*Credentials, error) { data, err := os.ReadFile(f.filepath) if err != nil { diff --git a/pkg/credentials/provider/file_provider_test.go b/pkg/credentials/provider/file_provider_test.go new file mode 100644 index 00000000..bf1dad12 --- /dev/null +++ b/pkg/credentials/provider/file_provider_test.go @@ -0,0 +1,39 @@ +package provider + +import ( + "context" + "os" + "path" + "testing" + "time" +) + +func TestFileProvider_Credentials(t *testing.T) { + d, err := os.MkdirTemp("", "TestFileProvider_Credentials") + if err != nil { + t.Errorf("should not error: %+v", err) + return + } + fp := path.Join(d, "test.json") + os.WriteFile(fp, []byte("abc"), 0600) + + f := NewFileProvider(fp, func(data []byte) (*Credentials, error) { + return &Credentials{ + AccessKeyId: "ak_TestFileProvider_Credentials", + AccessKeySecret: "sk_TestFileProvider_Credentials", + SecurityToken: "", + Expiration: time.Time{}, + }, nil + }, FileProviderOptions{}) + + cred, err := f.Credentials(context.TODO()) + if err != nil { + t.Errorf("should not error: %+v", err) + return + } + if cred.AccessKeyId != "ak_TestFileProvider_Credentials" || + cred.AccessKeySecret != "sk_TestFileProvider_Credentials" { + t.Errorf("unexpected case found: %+v", *cred) + return + } +} diff --git a/pkg/credentials/provider/function_provider_test.go b/pkg/credentials/provider/function_provider_test.go new file mode 100644 index 00000000..f75bf4bf --- /dev/null +++ b/pkg/credentials/provider/function_provider_test.go @@ -0,0 +1,29 @@ +package provider + +import ( + "context" + "testing" + "time" +) + +func TestFunctionProvider_Credentials(t *testing.T) { + f := NewFunctionProvider(func(ctx context.Context) (*Credentials, error) { + return &Credentials{ + AccessKeyId: "ak_TestFunctionProvider_Credentials", + AccessKeySecret: "sk_TestFunctionProvider_Credentials", + SecurityToken: "", + Expiration: time.Time{}, + }, nil + }) + + cred, err := f.Credentials(context.TODO()) + if err != nil { + t.Errorf("should not error: %+v", err) + return + } + if cred.AccessKeyId != "ak_TestFunctionProvider_Credentials" || + cred.AccessKeySecret != "sk_TestFunctionProvider_Credentials" { + t.Errorf("unexpected case found: %+v", *cred) + return + } +} diff --git a/pkg/credentials/provider/oidc_provider.go b/pkg/credentials/provider/oidc_provider.go index 73800685..531cce0e 100644 --- a/pkg/credentials/provider/oidc_provider.go +++ b/pkg/credentials/provider/oidc_provider.go @@ -19,6 +19,8 @@ const ( defaultEnvRoleArn = "ALIBABA_CLOUD_ROLE_ARN" defaultEnvOIDCProviderArn = "ALIBABA_CLOUD_OIDC_PROVIDER_ARN" defaultEnvOIDCTokenFile = "ALIBABA_CLOUD_OIDC_TOKEN_FILE" + + defaultExpiryWindowForAssumeRole = time.Minute * 10 ) var defaultSessionName = "default-session-name" @@ -98,12 +100,16 @@ func (o *OIDCProvider) Credentials(ctx context.Context) (*Credentials, error) { return o.u.Credentials(ctx) } +func (o *OIDCProvider) Stop(ctx context.Context) { + o.u.Stop(ctx) +} + func (o *OIDCProvider) getCredentials(ctx context.Context) (*Credentials, error) { roleArn := o.roleArn oidcProviderArn := o.oidcProviderArn tokenFile := o.oidcTokenFile if roleArn == "" || oidcProviderArn == "" || tokenFile == "" { - return nil, NewNotEnableError(errors.New("roleArn, oidcProviderArn or tokenFile is empty")) + return nil, NewNotEnableError(errors.New("roleArn, oidcProviderArn or oidcTokenFile is empty")) } tokenData, err := os.ReadFile(tokenFile) @@ -235,7 +241,7 @@ func (o *OIDCProviderOptions) applyDefaults() { o.SessionName = defaultSessionName } if o.ExpiryWindow == 0 { - o.ExpiryWindow = defaultExpiryWindow + o.ExpiryWindow = defaultExpiryWindowForAssumeRole } if o.EnvRoleArn == "" { o.EnvRoleArn = defaultEnvRoleArn diff --git a/pkg/credentials/provider/oidc_provider_test.go b/pkg/credentials/provider/oidc_provider_test.go new file mode 100644 index 00000000..99978ee8 --- /dev/null +++ b/pkg/credentials/provider/oidc_provider_test.go @@ -0,0 +1,44 @@ +package provider + +import ( + "context" + "fmt" + "net/http" + "os" + "testing" +) + +func TestOIDCProvider_Credentials_success(t *testing.T) { + s := setupHttpTestServer(t, func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, ` +{ + "Credentials": { + "AccessKeyId": "ak", + "AccessKeySecret": "sk", + "SecurityToken": "tt", + "Expiration": "2206-01-02T15:04:05Z" + } +} +`) + }) + + p := NewOIDCProvider(OIDCProviderOptions{ + STSEndpoint: s.URL, + RoleArn: "role_arn", + OIDCProviderArn: "oidc_arn", + OIDCTokenFile: os.Args[0], + }) + + cred, err := p.Credentials(context.TODO()) + if err != nil { + t.Log(err) + t.Errorf("should no error: %+v", err) + } + + if cred.AccessKeyId != "ak" || + cred.AccessKeySecret != "sk" || + cred.SecurityToken != "tt" || + cred.Expiration.IsZero() { + t.Errorf("got unexpected cred") + } +} diff --git a/pkg/credentials/provider/provider.go b/pkg/credentials/provider/provider.go index 80651ad7..11b9289d 100644 --- a/pkg/credentials/provider/provider.go +++ b/pkg/credentials/provider/provider.go @@ -14,18 +14,11 @@ type CredentialsProvider interface { Credentials(ctx context.Context) (*Credentials, error) } +type Stopper interface { + Stop(ctx context.Context) +} + func init() { name := path.Base(os.Args[0]) UserAgent = fmt.Sprintf("%s %s/%s ack-ram-tool/%s", name, runtime.GOOS, runtime.GOARCH, runtime.Version()) } - -type NotEnableError struct { - err error -} - -func NewNotEnableError(err error) *NotEnableError { - return &NotEnableError{err: err} -} -func (e NotEnableError) Error() string { - return fmt.Sprintf("this provider is not enabled: %s", e.err.Error()) -} diff --git a/pkg/credentials/provider/rolearn_provider.go b/pkg/credentials/provider/rolearn_provider.go index 7f607a03..0ae6ee64 100644 --- a/pkg/credentials/provider/rolearn_provider.go +++ b/pkg/credentials/provider/rolearn_provider.go @@ -69,6 +69,13 @@ func (r *RoleArnProvider) Credentials(ctx context.Context) (*Credentials, error) return r.u.Credentials(ctx) } +func (r *RoleArnProvider) Stop(ctx context.Context) { + r.u.Stop(ctx) + if s, ok := r.cp.(Stopper); ok { + s.Stop(ctx) + } +} + func (r *RoleArnProvider) getCredentials(ctx context.Context) (*Credentials, error) { return r.assumeRole(ctx, r.roleArn) } @@ -200,7 +207,7 @@ func (o *RoleArnProviderOptions) applyDefaults() { o.SessionName = defaultSessionName } if o.ExpiryWindow == 0 { - o.ExpiryWindow = defaultExpiryWindow + o.ExpiryWindow = defaultExpiryWindowForAssumeRole } if o.Logger == nil { o.Logger = defaultLog diff --git a/pkg/credentials/provider/rolearn_provider_test.go b/pkg/credentials/provider/rolearn_provider_test.go new file mode 100644 index 00000000..2f55263d --- /dev/null +++ b/pkg/credentials/provider/rolearn_provider_test.go @@ -0,0 +1,110 @@ +package provider + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" +) + +func TestRoleArnProvider_Credentials_success(t *testing.T) { + s := setupHttpTestServer(t, func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, ` +{ + "Credentials": { + "AccessKeyId": "ak", + "AccessKeySecret": "sk", + "SecurityToken": "tt", + "Expiration": "2206-01-02T15:04:05Z" + } +} +`) + }) + + p := NewRoleArnProvider( + NewAccessKeyProvider("ak1", "sk1"), + "role_arn", + RoleArnProviderOptions{ + STSEndpoint: s.URL, + }, + ) + + cred, err := p.Credentials(context.TODO()) + if err != nil { + t.Log(err) + t.Errorf("should no error: %+v", err) + } + + if cred.AccessKeyId != "ak" || + cred.AccessKeySecret != "sk" || + cred.SecurityToken != "tt" || + cred.Expiration.IsZero() { + t.Errorf("got unexpected cred") + } +} + +func TestRoleArnProvider_Credentials_stop_with_no_stop_method_cp(t *testing.T) { + s := setupHttpTestServer(t, func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, ` +{ + "Credentials": { + "AccessKeyId": "ak", + "AccessKeySecret": "sk", + "SecurityToken": "tt", + "Expiration": "2206-01-02T15:04:05Z" + } +} +`) + }) + + p := NewRoleArnProvider( + NewAccessKeyProvider("ak1", "sk1"), + "role_arn", + RoleArnProviderOptions{ + STSEndpoint: s.URL, + Logger: TLogger{t}, + }, + ) + + p.Stop(context.TODO()) + p.Stop(context.TODO()) +} + +func TestRoleArnProvider_Credentials_stop_with_stop_method_cp(t *testing.T) { + s := setupHttpTestServer(t, func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, ` +{ + "Credentials": { + "AccessKeyId": "ak", + "AccessKeySecret": "sk", + "SecurityToken": "tt", + "Expiration": "2206-01-02T15:04:05Z" + } +} +`) + }) + + p := NewRoleArnProvider( + NewOIDCProvider(OIDCProviderOptions{ + Logger: TLogger{t}, + }), + "role_arn", + RoleArnProviderOptions{ + STSEndpoint: s.URL, + Logger: TLogger{t}, + }, + ) + + p.Stop(context.TODO()) + p.Stop(context.TODO()) +} + +func setupHttpTestServer(t *testing.T, handler func(w http.ResponseWriter, r *http.Request)) *httptest.Server { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + handler(w, r) + }) + s := httptest.NewServer(mux) + return s +} diff --git a/pkg/credentials/provider/ststoken_provider_test.go b/pkg/credentials/provider/ststoken_provider_test.go new file mode 100644 index 00000000..ee197a0d --- /dev/null +++ b/pkg/credentials/provider/ststoken_provider_test.go @@ -0,0 +1,43 @@ +package provider + +import ( + "context" + "testing" + "time" +) + +func TestSTSTokenProvider_Credentials(t *testing.T) { + p := NewSTSTokenProvider("ak", "sk", "tt") + cred, err := p.Credentials(context.TODO()) + + if err != nil { + t.Log(err) + t.Errorf("should not return error: %+v", err) + } + + if cred.AccessKeyId != "ak" || + cred.AccessKeySecret != "sk" || + cred.SecurityToken != "tt" || + !cred.Expiration.IsZero() { + t.Error("cred value is not expected") + } +} + +func TestSTSTokenProvider_SetExpiration(t *testing.T) { + p := NewSTSTokenProvider("ak", "sk", "tt") + tm := time.Now() + p.SetExpiration(tm) + cred, err := p.Credentials(context.TODO()) + + if err != nil { + t.Log(err) + t.Errorf("should not return error: %+v", err) + } + + if cred.AccessKeyId != "ak" || + cred.AccessKeySecret != "sk" || + cred.SecurityToken != "tt" || + cred.Expiration.IsZero() { + t.Error("cred value is not expected") + } +} diff --git a/pkg/credentials/provider/updater.go b/pkg/credentials/provider/updater.go index 0a51285e..705f436d 100644 --- a/pkg/credentials/provider/updater.go +++ b/pkg/credentials/provider/updater.go @@ -25,6 +25,9 @@ type Updater struct { Logger Logger nowFunc func() time.Time logPrefix string + + doneCh chan struct{} + stopped bool } type UpdaterOptions struct { @@ -45,6 +48,7 @@ func NewUpdater(getter getCredentialsFunc, opts UpdaterOptions) *Updater { Logger: opts.Logger, nowFunc: time.Now, logPrefix: opts.LogPrefix, + doneCh: make(chan struct{}), } return u } @@ -57,17 +61,36 @@ func (u *Updater) Start(ctx context.Context) { go u.startRefreshLoop(ctx) } +func (u *Updater) Stop(shutdownCtx context.Context) { + u.logger().Debug(fmt.Sprintf("%s start to stop...", u.logPrefix)) + + go func() { + u.lockForCred.Lock() + defer u.lockForCred.Unlock() + if u.stopped { + return + } + u.stopped = true + close(u.doneCh) + }() + + select { + case <-shutdownCtx.Done(): + case <-u.doneCh: + } +} + func (u *Updater) startRefreshLoop(ctx context.Context) { ticket := time.NewTicker(u.refreshPeriod) defer ticket.Stop() - u.refreshCredForLoop(ctx) - loop: for { select { case <-ctx.Done(): break loop + case <-u.doneCh: + break loop case <-ticket.C: u.refreshCredForLoop(ctx) } @@ -101,7 +124,7 @@ func (u *Updater) refreshCredForLoop(ctx context.Context) { if err == nil { return } - if _, ok := err.(*NotEnableError); ok { + if IsNotEnableError(err) { return } if i < maxRetry-1 { @@ -113,7 +136,7 @@ func (u *Updater) refreshCredForLoop(ctx context.Context) { func (u *Updater) refreshCred(ctx context.Context) error { cred, err := u.getCredentials(ctx) if err != nil { - if _, ok := err.(*NotEnableError); ok { + if IsNotEnableError(err) { return err } u.logger().Error(err, fmt.Sprintf("%s refresh credentials failed: %s", u.logPrefix, err)) @@ -151,8 +174,11 @@ func (u *Updater) Expired() bool { func (u *Updater) expired(expiryDelta time.Duration) bool { exp := u.expiration() + if expiryDelta > 0 { + exp = exp.Add(-expiryDelta) + } - return exp.Add(-expiryDelta).Before(u.now()) + return exp.Before(u.now()) } func (u *Updater) expiration() time.Time { diff --git a/pkg/credentials/provider/updater_test.go b/pkg/credentials/provider/updater_test.go index 5802ddc7..59f0fb92 100644 --- a/pkg/credentials/provider/updater_test.go +++ b/pkg/credentials/provider/updater_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "sync/atomic" "testing" "time" ) @@ -25,12 +26,12 @@ func (d TLogger) Error(err error, msg string) { } func TestUpdater_refreshCredForLoop_refresh(t *testing.T) { - var callCount int + var callCount int32 fakeCred := Credentials{ Expiration: time.Now().Add(time.Minute), } u := NewUpdater(func(ctx context.Context) (*Credentials, error) { - callCount++ + atomic.AddInt32(&callCount, 1) return &fakeCred, nil }, UpdaterOptions{ ExpiryWindow: 0, @@ -39,8 +40,10 @@ func TestUpdater_refreshCredForLoop_refresh(t *testing.T) { }) u.refreshCredForLoop(context.TODO()) - if callCount != 1 { - t.Errorf("callCount should be 1 but got %d", callCount) + + cv := atomic.LoadInt32(&callCount) + if cv != 1 { + t.Errorf("callCount should be 1 but got %d", cv) } ret := u.Expired() if ret { @@ -48,8 +51,9 @@ func TestUpdater_refreshCredForLoop_refresh(t *testing.T) { } u.refreshCredForLoop(context.TODO()) - if callCount != 1 { - t.Errorf("callCount should be 1 but got %d", callCount) + cv = atomic.LoadInt32(&callCount) + if cv != 1 { + t.Errorf("callCount should be 1 but got %d", cv) } u.nowFunc = func() time.Time { @@ -62,8 +66,9 @@ func TestUpdater_refreshCredForLoop_refresh(t *testing.T) { fakeCred.Expiration = time.Now().Add(time.Minute * 5) u.refreshCredForLoop(context.TODO()) - if callCount != 2 { - t.Errorf("callCount should be 2 but got %d", callCount) + cv = atomic.LoadInt32(&callCount) + if cv != 2 { + t.Errorf("callCount should be 2 but got %d", cv) } ret = u.Expired() if ret { @@ -72,10 +77,10 @@ func TestUpdater_refreshCredForLoop_refresh(t *testing.T) { } func TestUpdater_refreshCredForLoop_erorr(t *testing.T) { - var callCount int + var callCount int32 u := NewUpdater(func(ctx context.Context) (*Credentials, error) { - callCount++ + atomic.AddInt32(&callCount, 1) return nil, errors.New("error message") }, UpdaterOptions{ ExpiryWindow: 0, @@ -84,8 +89,9 @@ func TestUpdater_refreshCredForLoop_erorr(t *testing.T) { }) u.refreshCredForLoop(context.TODO()) - if callCount != 5 { - t.Errorf("callCount should be 5 but got %d", callCount) + cv := atomic.LoadInt32(&callCount) + if cv != 5 { + t.Errorf("callCount should be 5 but got %d", cv) } ret := u.Expired() if !ret { @@ -94,12 +100,12 @@ func TestUpdater_refreshCredForLoop_erorr(t *testing.T) { } func TestUpdater_Credentials_refresh(t *testing.T) { - var callCount int + var callCount int32 fakeCred := Credentials{ Expiration: time.Now().Add(time.Minute), } u := NewUpdater(func(ctx context.Context) (*Credentials, error) { - callCount++ + atomic.AddInt32(&callCount, 1) return &fakeCred, nil }, UpdaterOptions{ ExpiryWindow: 0, @@ -107,35 +113,126 @@ func TestUpdater_Credentials_refresh(t *testing.T) { Logger: TLogger{t: t}, }) - u.Credentials(context.TODO()) - if callCount != 1 { - t.Errorf("callCount should be 1 but got %d", callCount) - } - ret := u.Expired() - if ret { - t.Errorf("should not expired") - } + t.Run("Credentials use cache", func(t *testing.T) { + u.Credentials(context.TODO()) + cv := atomic.LoadInt32(&callCount) + if cv != 1 { + t.Errorf("callCount should be 1 but got %d", cv) + } + ret := u.Expired() + if ret { + t.Errorf("should not expired") + } + + u.Credentials(context.TODO()) + cv = atomic.LoadInt32(&callCount) + if cv != 1 { + t.Errorf("callCount should be 1 but got %d", cv) + } + }) - u.Credentials(context.TODO()) - if callCount != 1 { - t.Errorf("callCount should be 1 but got %d", callCount) - } + t.Run("Credentials expired", func(t *testing.T) { + u.nowFunc = func() time.Time { + return time.Now().Add(time.Minute * 2) + } + ret := u.Expired() + if !ret { + t.Errorf("should expired") + } + }) - u.nowFunc = func() time.Time { - return time.Now().Add(time.Minute) - } - ret = u.Expired() - if !ret { - t.Errorf("should expired") - } + t.Run("not expire, should not refresh", func(t *testing.T) { + fakeCred.Expiration = time.Now().Add(time.Minute * 5) + u.Credentials(context.TODO()) + cv := atomic.LoadInt32(&callCount) + if cv != 2 { + t.Errorf("callCount should be 2 but got %d", cv) + } + ret := u.Expired() + if ret { + t.Errorf("should not expired") + } + }) +} - fakeCred.Expiration = time.Now().Add(time.Minute * 5) - u.Credentials(context.TODO()) - if callCount != 2 { - t.Errorf("callCount should be 2 but got %d", callCount) +func TestUpdater_expired(t *testing.T) { + u := &Updater{} + u.setCred(&Credentials{Expiration: time.Now().Add(time.Minute)}) + + t.Run("expiryDelta=0", func(t *testing.T) { + ret := u.expired(0) + if ret { + t.Errorf("should be false") + } + }) + + t.Run("expiryDelta > 0", func(t *testing.T) { + ret := u.expired(time.Minute * 5) + if !ret { + t.Errorf("should be true") + } + }) +} + +func TestUpdater_stop(t *testing.T) { + var callCount int32 + fakeCred := Credentials{ + Expiration: time.Now().Add(-time.Minute), } - ret = u.Expired() - if ret { - t.Errorf("should not expired") + u := NewUpdater(func(ctx context.Context) (*Credentials, error) { + atomic.AddInt32(&callCount, 1) + return &fakeCred, nil + }, UpdaterOptions{ + ExpiryWindow: 0, + RefreshPeriod: time.Millisecond * 100, + Logger: TLogger{t: t}, + }) + + u.Start(context.TODO()) + + t.Run("test-refresh", func(t *testing.T) { + time.Sleep(time.Second) + cv := atomic.LoadInt32(&callCount) + if cv < 1 { + t.Errorf("callCount should be >1 but got %d", cv) + } + }) + + t.Run("test-stop", func(t *testing.T) { + u.Stop(context.TODO()) + time.Sleep(time.Second) + + curr := atomic.LoadInt32(&callCount) + time.Sleep(time.Second) + + cv := atomic.LoadInt32(&callCount) + if cv != curr { + t.Errorf("callCount should be %d but got %d", curr, cv) + } + }) + + t.Run("test-stop-multiple-times", func(t *testing.T) { + u.Stop(context.TODO()) + u.Stop(context.TODO()) + u.Stop(context.TODO()) + }) +} + +func TestUpdater_stop_no_start(t *testing.T) { + var callCount int32 + fakeCred := Credentials{ + Expiration: time.Now().Add(-time.Minute), } + u := NewUpdater(func(ctx context.Context) (*Credentials, error) { + atomic.AddInt32(&callCount, 1) + return &fakeCred, nil + }, UpdaterOptions{ + ExpiryWindow: 0, + RefreshPeriod: 0, + Logger: TLogger{t: t}, + }) + + u.Stop(context.TODO()) + u.Stop(context.TODO()) + u.Stop(context.TODO()) }