diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index ad58de3d..5999a874 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -16,7 +16,7 @@ jobs:
contents: write
steps:
- name: Checkout
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Go
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 309fe47f..7cb7875a 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -9,7 +9,7 @@ jobs:
name: lint
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v4
with:
diff --git a/.github/workflows/website.yaml b/.github/workflows/website.yaml
index 6cf946b9..a1580e63 100644
--- a/.github/workflows/website.yaml
+++ b/.github/workflows/website.yaml
@@ -19,7 +19,7 @@ jobs:
run:
working-directory: website
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v4
- name: Setup Node
uses: actions/setup-node@v3
diff --git a/.gitignore b/.gitignore
index 6e3d0e2f..2034cfe8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -23,3 +23,6 @@ kubeconfig
/examples/rrsa/kaniko-in-ack/deploy.yaml
/ci/ossutil/ossutil
/cputil
+.terraform/
+.terraform.*
+terraform.tfstate*
diff --git a/Makefile b/Makefile
index dfd1000d..6b14178a 100644
--- a/Makefile
+++ b/Makefile
@@ -17,7 +17,8 @@ build:
.PHONY: test
test:
- go test -v ./...
+ go test -race -v ./...
+ cd pkg/credentials/provider && go test -race -v ./...
.PHONY: e2e
e2e:
diff --git a/examples/rrsa/aliyuncli-demo/README.md b/examples/rrsa/aliyuncli-demo/README.md
index e307250e..87129e1f 100644
--- a/examples/rrsa/aliyuncli-demo/README.md
+++ b/examples/rrsa/aliyuncli-demo/README.md
@@ -1,6 +1,24 @@
# aliyun cli demo
-## Usage
+config.json:
+
+```
+{
+ "current": "default",
+ "profiles": [
+ {
+ "name": "default",
+ "mode": "External",
+ "region_id": "cn-hangzhou",
+ "process_command": "ack-ram-tool export-credentials --ignore-aliyun-cli-credentials --log-level=ERROR",
+ "credentials_uri": ""
+ }
+ ],
+ "meta_path": ""
+}
+```
+
+## Demo
1. Enable RRSA:
diff --git a/examples/rrsa/aliyunlogcli-demo/README.md b/examples/rrsa/aliyunlogcli-demo/README.md
index c85bcae9..311aea05 100644
--- a/examples/rrsa/aliyunlogcli-demo/README.md
+++ b/examples/rrsa/aliyunlogcli-demo/README.md
@@ -1,6 +1,11 @@
# aliyunlog cli demo
-## Usage
+```
+ack-ram-tool export-credentials --format=environment-variables -- \
+ aliyunlog log list_project --region-endpoint=cn-hangzhou.log.aliyuncs.com
+```
+
+## Demo
1. Enable RRSA:
diff --git a/examples/rrsa/cpp-demo/cpp-sdk/README.md b/examples/rrsa/cpp-demo/cpp-sdk/README.md
index 97ef627f..5dbdd498 100644
--- a/examples/rrsa/cpp-demo/cpp-sdk/README.md
+++ b/examples/rrsa/cpp-demo/cpp-sdk/README.md
@@ -1,6 +1,6 @@
# cpp-sdk
-## Usage
+## Demo
1. Enable RRSA:
diff --git a/examples/rrsa/go-sdk/README.md b/examples/rrsa/go-sdk/README.md
index ef44d7e4..4d9a8586 100644
--- a/examples/rrsa/go-sdk/README.md
+++ b/examples/rrsa/go-sdk/README.md
@@ -2,7 +2,14 @@
Using [Alibaba Could Go SDK](https://github.com/aliyun/alibabacloud-go-sdk) with RRSA Auth.
-## Usage
+```
+go get github.com/aliyun/credentials-go@v1.2.7
+```
+
+https://github.com/aliyun/credentials-go
+
+
+## Demo
1. Enable RRSA:
diff --git a/examples/rrsa/java-sdk/README.md b/examples/rrsa/java-sdk/README.md
index 7dc54b5d..2574a004 100644
--- a/examples/rrsa/java-sdk/README.md
+++ b/examples/rrsa/java-sdk/README.md
@@ -2,8 +2,18 @@
Using [Alibaba Could Java SDK](https://github.com/aliyun/alibabacloud-java-sdk) with RRSA Auth.
+```
+
+ com.aliyun
+ credentials-java
+ 0.2.12
+
+```
+
+https://github.com/aliyun/credentials-java
+
-## Usage
+## Demo
1. Enable RRSA:
diff --git a/examples/rrsa/kaniko-in-ack/README.md b/examples/rrsa/kaniko-in-ack/README.md
index c2087109..7b6ec34d 100644
--- a/examples/rrsa/kaniko-in-ack/README.md
+++ b/examples/rrsa/kaniko-in-ack/README.md
@@ -5,7 +5,7 @@ Running kaniko in ACK:
* build image with kaniko
* push image to the ACR with RRSA Auth
-## Usage
+## Demo
1. Enable RRSA:
diff --git a/examples/rrsa/log-go-sdk/README.md b/examples/rrsa/log-go-sdk/README.md
index 0afccb83..642e47d2 100644
--- a/examples/rrsa/log-go-sdk/README.md
+++ b/examples/rrsa/log-go-sdk/README.md
@@ -2,7 +2,8 @@
Using [aliyun-log-go-sdk](https://github.com/aliyun/aliyun-log-go-sdk) with RRSA Auth.
-## Usage
+
+## Demo
1. Enable RRSA:
diff --git a/examples/rrsa/log-go-sdk/go.mod b/examples/rrsa/log-go-sdk/go.mod
index d6157ad5..7871d22f 100644
--- a/examples/rrsa/log-go-sdk/go.mod
+++ b/examples/rrsa/log-go-sdk/go.mod
@@ -3,7 +3,7 @@ module github.com/AliyunContainerService/ack-ram-tool/examples/rrsa/log-go-sdk
go 1.16
require (
- github.com/AliyunContainerService/ack-ram-tool/pkg/credentials/provider v0.7.1
+ github.com/AliyunContainerService/ack-ram-tool/pkg/credentials/provider v0.9.0
github.com/aliyun/aliyun-log-go-sdk v0.1.54
github.com/stretchr/testify v1.5.1 // indirect
golang.org/x/net v0.7.0 // indirect
diff --git a/examples/rrsa/log-go-sdk/go.sum b/examples/rrsa/log-go-sdk/go.sum
index 4952b96a..2c3462ac 100644
--- a/examples/rrsa/log-go-sdk/go.sum
+++ b/examples/rrsa/log-go-sdk/go.sum
@@ -1,7 +1,7 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
-github.com/AliyunContainerService/ack-ram-tool/pkg/credentials/provider v0.7.1 h1:PXpSLU9ghgbUvDgRSr2N+SPHV5Ze0dYoqwGM4LSyfc4=
-github.com/AliyunContainerService/ack-ram-tool/pkg/credentials/provider v0.7.1/go.mod h1:ULtI7L9xkNeJ07YNqSeT5EhjQAl1CpTgPcUn4KoNcuc=
+github.com/AliyunContainerService/ack-ram-tool/pkg/credentials/provider v0.9.0 h1:PqNYfVOnnbTN9d2X8Hg8JCcp7H53YwMWFA6//AYDAg0=
+github.com/AliyunContainerService/ack-ram-tool/pkg/credentials/provider v0.9.0/go.mod h1:ULtI7L9xkNeJ07YNqSeT5EhjQAl1CpTgPcUn4KoNcuc=
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0=
diff --git a/examples/rrsa/log-java-sdk/README.md b/examples/rrsa/log-java-sdk/README.md
index 26e12a2b..18255807 100644
--- a/examples/rrsa/log-java-sdk/README.md
+++ b/examples/rrsa/log-java-sdk/README.md
@@ -2,7 +2,18 @@
Using [aliyun-log-java-sdk](https://github.com/aliyun/aliyun-log-java-sdk) with RRSA Auth.
-## Usage
+```
+
+ com.aliyun
+ credentials-java
+ 0.2.12
+
+```
+
+https://github.com/aliyun/credentials-java
+
+
+## Demo
1. Enable RRSA:
diff --git a/examples/rrsa/nodejs-sdk/README.md b/examples/rrsa/nodejs-sdk/README.md
index acbe774b..033e90f2 100644
--- a/examples/rrsa/nodejs-sdk/README.md
+++ b/examples/rrsa/nodejs-sdk/README.md
@@ -2,7 +2,14 @@
Using [Alibaba Could Node.js/TypeScript SDK](https://github.com/aliyun/alibabacloud-typescript-sdk) with RRSA Auth.
-## Usage
+```
+npm install @alicloud/credentials
+```
+
+https://github.com/aliyun/credentials-nodejs
+
+
+## Demo
1. Enable RRSA:
diff --git a/examples/rrsa/oss-go-sdk/README.md b/examples/rrsa/oss-go-sdk/README.md
index 7c853b33..b2b306dd 100644
--- a/examples/rrsa/oss-go-sdk/README.md
+++ b/examples/rrsa/oss-go-sdk/README.md
@@ -2,7 +2,14 @@
Using [aliyun-oss-go-sdk](https://github.com/aliyun/aliyun-oss-go-sdk) with RRSA Auth.
-## Usage
+```
+go get github.com/aliyun/credentials-go@v1.2.7
+```
+
+https://github.com/aliyun/credentials-go
+
+
+## Demo
1. Enable RRSA:
diff --git a/examples/rrsa/oss-java-sdk/README.md b/examples/rrsa/oss-java-sdk/README.md
index 0fbb41b4..8f3ca1b3 100644
--- a/examples/rrsa/oss-java-sdk/README.md
+++ b/examples/rrsa/oss-java-sdk/README.md
@@ -2,7 +2,18 @@
Using [aliyun-oss-java-sdk](https://github.com/aliyun/aliyun-oss-java-sdk) with RRSA Auth.
-## Usage
+```
+
+ com.aliyun
+ credentials-java
+ 0.2.12
+
+```
+
+https://github.com/aliyun/credentials-java
+
+
+## Demo
1. Enable RRSA:
diff --git a/examples/rrsa/python3-sdk/README.md b/examples/rrsa/python3-sdk/README.md
index 1f46c2d4..808a471f 100644
--- a/examples/rrsa/python3-sdk/README.md
+++ b/examples/rrsa/python3-sdk/README.md
@@ -2,7 +2,14 @@
Using [Alibaba Could Python 3 SDK](https://github.com/aliyun/alibabacloud-python-sdk) with RRSA Auth.
-## Usage
+```
+pip install alibabacloud_credentials>=0.3.1
+```
+
+https://github.com/aliyun/credentials-python
+
+
+## Demo
1. Enable RRSA:
diff --git a/examples/rrsa/terraform-demo/rrsa-config/README.md b/examples/rrsa/terraform-demo/rrsa-config/README.md
new file mode 100644
index 00000000..4fa0f7ce
--- /dev/null
+++ b/examples/rrsa/terraform-demo/rrsa-config/README.md
@@ -0,0 +1,22 @@
+# RRSA Configuration Via Terraform
+
+
+```
+aliyun/terraform-provider-alicloud > v1.171.0
+```
+
+https://registry.terraform.io/providers/aliyun/alicloud/latest
+
+
+
+## Demo
+
+```
+export ALICLOUD_ACCESS_KEY=
+export ALICLOUD_SECRET_KEY=
+export ALICLOUD_REGION="cn-hangzhou"
+
+terraform init
+terraform plan
+terraform apply
+```
diff --git a/examples/rrsa/terraform-demo/rrsa-config/main.tf b/examples/rrsa/terraform-demo/rrsa-config/main.tf
new file mode 100644
index 00000000..03525b43
--- /dev/null
+++ b/examples/rrsa/terraform-demo/rrsa-config/main.tf
@@ -0,0 +1,121 @@
+provider "alicloud" {
+}
+
+variable "k8s_name_prefix" {
+ description = "The name prefix used to create ASK cluster."
+ default = "ask-rrsa-example"
+}
+
+resource "random_uuid" "this" {}
+
+
+locals {
+ k8s_name_ask = substr(join("-", [var.k8s_name_prefix,"ask"]), 0, 63)
+ new_vpc_name = "tf-vpc-172-16"
+ new_vsw_name = "tf-vswitch-172-16-0"
+}
+
+data "alicloud_zones" "default" {
+ available_resource_creation = "VSwitch"
+}
+
+resource "alicloud_vpc" "vpc" {
+ vpc_name = local.new_vpc_name
+ cidr_block = "172.16.0.0/12"
+}
+
+resource "alicloud_vswitch" "vsw" {
+ vswitch_name = local.new_vsw_name
+ vpc_id = alicloud_vpc.vpc.id
+ cidr_block = cidrsubnet(alicloud_vpc.vpc.cidr_block, 8, 8)
+ zone_id = data.alicloud_zones.default.zones[0].id
+}
+
+
+resource "alicloud_cs_serverless_kubernetes" "serverless" {
+ name = local.k8s_name_ask
+ version = "1.26.3-aliyun.1"
+ cluster_spec = "ack.pro.small"
+ vpc_id = alicloud_vpc.vpc.id
+ vswitch_ids = split(",", join(",", alicloud_vswitch.vsw.*.id))
+ new_nat_gateway = false
+ endpoint_public_access_enabled = false
+ deletion_protection = false
+ load_balancer_spec = "slb.s2.small"
+ time_zone = "Asia/Shanghai"
+ service_cidr = "10.13.0.0/16"
+ service_discovery_types = ["CoreDNS"]
+
+ # Enable RRSA
+ enable_rrsa = true
+}
+
+
+# k8s service account info
+variable "k8s_namespace" {
+ default = "test-rrsa-ns"
+}
+variable "k8s_service_account" {
+ default = "foo-bar-manager-sa"
+}
+
+# Create a new RAM Role.
+resource "alicloud_ram_role" "role" {
+ name = "rrsa-demo-${alicloud_cs_serverless_kubernetes.serverless.id}"
+ document = <= c.runtimeSwitchCacheDuration {
+ c.logger().Debug(fmt.Sprintf("%s trigger select provider again", c.logPrefix))
+ return nil
+ }
+
+ return p
+}
+
+func (c *ChainProvider) setCurrentProvider(p CredentialsProvider) {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ prePT := fmt.Sprintf("%T", c.currentProvider)
+ pT := fmt.Sprintf("%T", p)
+ if prePT != pT {
+ c.logger().Info(fmt.Sprintf("%s switch to using new provider: %s -> %s", c.logPrefix, prePT, pT))
+ }
+
+ c.lastSelectProviderTime = time.Now()
+ c.currentProvider = p
}
func (c *ChainProvider) logger() Logger {
@@ -59,31 +148,45 @@ func (c *ChainProvider) logger() Logger {
}
type DefaultChainProviderOptions struct {
+ EnableRuntimeSwitch bool
+ RuntimeSwitchCacheDuration time.Duration
+
STSEndpoint string
ExpiryWindow time.Duration
RefreshPeriod time.Duration
Logger Logger
+
+ logPrefix string
}
func NewDefaultChainProvider(opts DefaultChainProviderOptions) *ChainProvider {
- p := NewChainProvider(
- NewEnvProvider(EnvProviderOptions{}),
- NewOIDCProvider(OIDCProviderOptions{
- STSEndpoint: opts.STSEndpoint,
- ExpiryWindow: opts.ExpiryWindow,
- RefreshPeriod: opts.RefreshPeriod,
- Logger: opts.Logger,
- }),
- NewEncryptedFileProvider(EncryptedFileProviderOptions{
- ExpiryWindow: opts.ExpiryWindow,
- RefreshPeriod: opts.RefreshPeriod,
- Logger: opts.Logger,
- }),
- NewECSMetadataProvider(ECSMetadataProviderOptions{
- ExpiryWindow: opts.ExpiryWindow,
- RefreshPeriod: opts.RefreshPeriod,
- Logger: opts.Logger,
- }),
+ opts.applyDefaults()
+
+ p := NewChainProviderWithOptions(
+ []CredentialsProvider{
+ NewEnvProvider(EnvProviderOptions{}),
+ NewOIDCProvider(OIDCProviderOptions{
+ STSEndpoint: opts.STSEndpoint,
+ ExpiryWindow: opts.ExpiryWindow,
+ RefreshPeriod: opts.RefreshPeriod,
+ Logger: opts.Logger,
+ }),
+ NewEncryptedFileProvider(EncryptedFileProviderOptions{
+ ExpiryWindow: opts.ExpiryWindow,
+ RefreshPeriod: opts.RefreshPeriod,
+ Logger: opts.Logger,
+ }),
+ NewECSMetadataProvider(ECSMetadataProviderOptions{
+ ExpiryWindow: opts.ExpiryWindow,
+ RefreshPeriod: opts.RefreshPeriod,
+ Logger: opts.Logger,
+ }),
+ },
+ ChainProviderOptions{
+ EnableRuntimeSwitch: opts.EnableRuntimeSwitch,
+ RuntimeSwitchCacheDuration: opts.RuntimeSwitchCacheDuration,
+ logPrefix: opts.logPrefix,
+ },
)
p.Logger = opts.Logger
return p
@@ -100,3 +203,18 @@ func DefaultChainProviderWithLogger(l Logger) *ChainProvider {
Logger: l,
})
}
+
+func (o *ChainProviderOptions) applyDefaults() {
+ if o.RuntimeSwitchCacheDuration <= 0 {
+ o.RuntimeSwitchCacheDuration = defaultRuntimeSwitchCacheDuration
+ }
+ if o.logPrefix == "" {
+ o.logPrefix = "[ChainProvider]"
+ }
+}
+
+func (o *DefaultChainProviderOptions) applyDefaults() {
+ if o.logPrefix == "" {
+ o.logPrefix = "[DefaultChainProvider]"
+ }
+}
diff --git a/pkg/credentials/provider/chain_provider_test.go b/pkg/credentials/provider/chain_provider_test.go
new file mode 100644
index 00000000..8ef13009
--- /dev/null
+++ b/pkg/credentials/provider/chain_provider_test.go
@@ -0,0 +1,88 @@
+package provider
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+func TestChainProvider_Credentials_success(t *testing.T) {
+ p1 := NewAccessKeyProvider("", "")
+ p2 := NewAccessKeyProvider("", "")
+ p3 := NewSTSTokenProvider("ak3", "sk3", "sts3")
+ cp := NewChainProvider(p1, p2, p3)
+
+ cred, err := cp.Credentials(context.TODO())
+ if err != nil {
+ t.Errorf("should no error: %+v", err)
+ }
+ if cred.AccessKeyId != "ak3" ||
+ cred.AccessKeySecret != "sk3" ||
+ cred.SecurityToken != "sts3" {
+ t.Errorf("unexpect ret: %+v", *cred)
+ }
+}
+
+func TestChainProvider_Credentials_no_provider(t *testing.T) {
+ cp := NewChainProvider(NewAccessKeyProvider("", ""))
+ cred, err := cp.Credentials(context.TODO())
+ if err == nil {
+ t.Errorf("should return error: %+v", err)
+ }
+ t.Log(err)
+ if cred != nil {
+ t.Errorf("should return nil: %+v", *cred)
+ }
+}
+
+func TestChainProvider_Stop(t *testing.T) {
+ var callCount int32
+ s := setupHttpTestServer(t, func(w http.ResponseWriter, r *http.Request) {
+ atomic.AddInt32(&callCount, 1)
+ fmt.Fprint(w, `
+{}
+`)
+ })
+
+ p := NewRoleArnProvider(
+ NewAccessKeyProvider("ak1", "sk1"),
+ "role_arn",
+ RoleArnProviderOptions{
+ STSEndpoint: s.URL,
+ RefreshPeriod: time.Millisecond * 100,
+ Logger: TLogger{t: t},
+ },
+ )
+
+ cp := NewChainProvider(p)
+ cp.Logger = TLogger{t: t}
+ cp.Credentials(context.TODO())
+
+ cv := atomic.LoadInt32(&callCount)
+ if cv < 1 {
+ t.Errorf("callCount should >= 1: %v", cv)
+ }
+
+ time.Sleep(time.Second)
+ cv = atomic.LoadInt32(&callCount)
+ if cv <= 1 {
+ t.Errorf("callCount should > 1: %v", cv)
+ }
+
+ cp.Stop(context.TODO())
+ time.Sleep(time.Second)
+ curr := atomic.LoadInt32(&callCount)
+ time.Sleep(time.Second)
+
+ cv = atomic.LoadInt32(&callCount)
+ if cv != curr {
+ t.Errorf("callCount should == %v: %v", curr, cv)
+ }
+
+ cp.Stop(context.TODO())
+ cp.Stop(context.TODO())
+ cp.Stop(context.TODO())
+}
diff --git a/pkg/credentials/provider/ecsmetadata_provider.go b/pkg/credentials/provider/ecsmetadata_provider.go
index 33263826..3acef2e1 100644
--- a/pkg/credentials/provider/ecsmetadata_provider.go
+++ b/pkg/credentials/provider/ecsmetadata_provider.go
@@ -72,6 +72,10 @@ func (e *ECSMetadataProvider) Credentials(ctx context.Context) (*Credentials, er
return e.u.Credentials(ctx)
}
+func (e *ECSMetadataProvider) Stop(ctx context.Context) {
+ e.u.Stop(ctx)
+}
+
type ecsMetadataStsResponse struct {
AccessKeyId string `json:"AccessKeyId"`
AccessKeySecret string `json:"AccessKeySecret"`
@@ -136,6 +140,7 @@ func (e *ECSMetadataProvider) getMedataToken(ctx context.Context) (string, error
return e.metadataToken, nil
}
+ e.logger().Debug("start to get metadata token")
h := http.Header{}
h.Set("X-aliyun-ecs-metadata-token-ttl-seconds", fmt.Sprintf("%d", e.metadataTokenTTLSeconds))
body, err := e.getMedataData(ctx, http.MethodPut, "/latest/api/token", h)
diff --git a/pkg/credentials/provider/ecsmetadata_provider_test.go b/pkg/credentials/provider/ecsmetadata_provider_test.go
new file mode 100644
index 00000000..f0dec92b
--- /dev/null
+++ b/pkg/credentials/provider/ecsmetadata_provider_test.go
@@ -0,0 +1,49 @@
+package provider
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+ "testing"
+)
+
+func TestECSMetadataProvider_Credentials_success(t *testing.T) {
+ s := setupHttpTestServer(t, func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/latest/api/token":
+ {
+ fmt.Fprint(w, "token-xxx")
+ return
+ }
+ default:
+ {
+ fmt.Fprint(w, `
+{
+ "AccessKeyId": "ak",
+ "AccessKeySecret": "sk",
+ "SecurityToken": "tt",
+ "Expiration": "2206-01-02T15:04:05Z"
+}
+`)
+ }
+ }
+ })
+
+ p := NewECSMetadataProvider(ECSMetadataProviderOptions{
+ Endpoint: s.URL,
+ RoleName: "test",
+ })
+
+ cred, err := p.Credentials(context.TODO())
+ if err != nil {
+ t.Log(err)
+ t.Errorf("should no error: %+v", err)
+ }
+
+ if cred.AccessKeyId != "ak" ||
+ cred.AccessKeySecret != "sk" ||
+ cred.SecurityToken != "tt" ||
+ cred.Expiration.IsZero() {
+ t.Errorf("got unexpected cred")
+ }
+}
diff --git a/pkg/credentials/provider/env_provider.go b/pkg/credentials/provider/env_provider.go
index 58151b66..e05fa768 100644
--- a/pkg/credentials/provider/env_provider.go
+++ b/pkg/credentials/provider/env_provider.go
@@ -13,41 +13,58 @@ const (
)
type EnvProvider struct {
- cred *Credentials
-
- envAccessKeyId string
- envAccessKeySecret string
- envSecurityToken string
+ cp *ChainProvider
}
type EnvProviderOptions struct {
EnvAccessKeyId string
EnvAccessKeySecret string
EnvSecurityToken string
+
+ EnvRoleArn string
+ EnvOIDCProviderArn string
+ EnvOIDCTokenFile string
}
func NewEnvProvider(opts EnvProviderOptions) *EnvProvider {
opts.applyDefaults()
- return &EnvProvider{
- cred: &Credentials{
- AccessKeyId: os.Getenv(opts.EnvAccessKeyId),
- AccessKeySecret: os.Getenv(opts.EnvAccessKeySecret),
- SecurityToken: os.Getenv(opts.EnvSecurityToken),
- },
- envAccessKeyId: opts.EnvAccessKeyId,
- envAccessKeySecret: opts.EnvAccessKeySecret,
- envSecurityToken: opts.EnvSecurityToken,
- }
+ e := &EnvProvider{}
+ e.cp = e.getProvider(opts)
+
+ return e
}
func (e *EnvProvider) Credentials(ctx context.Context) (*Credentials, error) {
- if e.cred.AccessKeyId == "" || e.cred.AccessKeySecret == "" {
- return nil, NewNotEnableError(fmt.Errorf("env %s or %s is empty",
- e.envAccessKeyId, e.envAccessKeySecret))
+ cred, err := e.cp.Credentials(ctx)
+
+ if err != nil {
+ if IsNoAvailableProviderError(err) {
+ return nil, NewNotEnableError(fmt.Errorf("not found credentials from env: %w", err))
+ }
+ return nil, err
}
- return e.cred.DeepCopy(), nil
+ return cred.DeepCopy(), nil
+}
+
+func (e *EnvProvider) getProvider(opts EnvProviderOptions) *ChainProvider {
+ p1 := NewSTSTokenProvider(
+ os.Getenv(opts.EnvAccessKeyId),
+ os.Getenv(opts.EnvAccessKeySecret),
+ os.Getenv(opts.EnvSecurityToken),
+ )
+ p2 := NewOIDCProvider(OIDCProviderOptions{
+ RoleArn: os.Getenv(opts.EnvRoleArn),
+ OIDCProviderArn: os.Getenv(opts.EnvOIDCProviderArn),
+ OIDCTokenFile: os.Getenv(opts.EnvOIDCTokenFile),
+ })
+ p3 := NewAccessKeyProvider(
+ os.Getenv(opts.EnvAccessKeyId),
+ os.Getenv(opts.EnvAccessKeySecret),
+ )
+ cp := NewChainProvider(p1, p2, p3)
+ return cp
}
func (o *EnvProviderOptions) applyDefaults() {
@@ -60,4 +77,14 @@ func (o *EnvProviderOptions) applyDefaults() {
if o.EnvSecurityToken == "" {
o.EnvSecurityToken = envSecurityToken
}
+
+ if o.EnvRoleArn == "" {
+ o.EnvRoleArn = defaultEnvRoleArn
+ }
+ if o.EnvOIDCProviderArn == "" {
+ o.EnvOIDCProviderArn = defaultEnvOIDCProviderArn
+ }
+ if o.EnvOIDCTokenFile == "" {
+ o.EnvOIDCTokenFile = defaultEnvOIDCTokenFile
+ }
}
diff --git a/pkg/credentials/provider/env_provider_test.go b/pkg/credentials/provider/env_provider_test.go
new file mode 100644
index 00000000..697267c3
--- /dev/null
+++ b/pkg/credentials/provider/env_provider_test.go
@@ -0,0 +1,80 @@
+package provider
+
+import (
+ "context"
+ "os"
+ "testing"
+)
+
+func TestEnvProvider_Credentials(t *testing.T) {
+ envAk := "TestEnvProvider_Credentials_AK"
+ envSK := "TestEnvProvider_Credentials_SK"
+ envToken := "TestEnvProvider_Credentials_Token"
+ envRoleArn := "TestEnvProvider_Credentials_Role_ARN"
+ envOidcP := "TestEnvProvider_Credentials_OIDC_Pro"
+ envOidcT := "TestEnvProvider_Credentials_OIDC_Token"
+
+ t.Run("no env", func(t *testing.T) {
+ p := NewEnvProvider(EnvProviderOptions{
+ EnvAccessKeyId: envAk,
+ EnvAccessKeySecret: envSK,
+ EnvSecurityToken: envToken,
+ EnvRoleArn: envRoleArn,
+ EnvOIDCProviderArn: envOidcP,
+ EnvOIDCTokenFile: envOidcT,
+ })
+ cred, err := p.Credentials(context.TODO())
+ if err == nil {
+ t.Errorf("should return error: %+v", err)
+ }
+ t.Log(err)
+ if cred != nil {
+ t.Errorf("got unexpected cred: %+v", *cred)
+ }
+ })
+
+ t.Run("only ak env", func(t *testing.T) {
+ os.Setenv(envAk, "ak")
+ os.Setenv(envSK, "sk")
+ p := NewEnvProvider(EnvProviderOptions{
+ EnvAccessKeyId: envAk,
+ EnvAccessKeySecret: envSK,
+ EnvSecurityToken: envToken,
+ EnvRoleArn: envRoleArn,
+ EnvOIDCProviderArn: envOidcP,
+ EnvOIDCTokenFile: envOidcT,
+ })
+ cred, err := p.Credentials(context.TODO())
+ if err != nil {
+ t.Errorf("should no error: %+v", err)
+ }
+ if cred.AccessKeyId != "ak" ||
+ cred.AccessKeySecret != "sk" ||
+ cred.SecurityToken != "" {
+ t.Errorf("got unexpected cred: %+v", *cred)
+ }
+ })
+
+ t.Run("sts token env", func(t *testing.T) {
+ os.Setenv(envAk, "ak")
+ os.Setenv(envSK, "sk")
+ os.Setenv(envToken, "sts-token")
+ p := NewEnvProvider(EnvProviderOptions{
+ EnvAccessKeyId: envAk,
+ EnvAccessKeySecret: envSK,
+ EnvSecurityToken: envToken,
+ EnvRoleArn: envRoleArn,
+ EnvOIDCProviderArn: envOidcP,
+ EnvOIDCTokenFile: envOidcT,
+ })
+ cred, err := p.Credentials(context.TODO())
+ if err != nil {
+ t.Errorf("should no error: %+v", err)
+ }
+ if cred.AccessKeyId != "ak" ||
+ cred.AccessKeySecret != "sk" ||
+ cred.SecurityToken != "sts-token" {
+ t.Errorf("got unexpected cred: %+v", *cred)
+ }
+ })
+}
diff --git a/pkg/credentials/provider/error.go b/pkg/credentials/provider/error.go
new file mode 100644
index 00000000..5343c255
--- /dev/null
+++ b/pkg/credentials/provider/error.go
@@ -0,0 +1,45 @@
+package provider
+
+type NotEnableError struct {
+ err error
+}
+
+type NoAvailableProviderError struct {
+ err error
+}
+
+func NewNotEnableError(err error) *NotEnableError {
+ return &NotEnableError{err: err}
+}
+
+func NewNoAvailableProviderError(err error) *NoAvailableProviderError {
+ return &NoAvailableProviderError{err: err}
+}
+
+func (e NotEnableError) Error() string {
+ return e.err.Error()
+}
+
+func (e NoAvailableProviderError) Error() string {
+ return e.err.Error()
+}
+
+func IsNotEnableError(err error) bool {
+ if _, ok := err.(*NotEnableError); ok {
+ return true
+ }
+ if _, ok := err.(NotEnableError); ok {
+ return true
+ }
+ return false
+}
+
+func IsNoAvailableProviderError(err error) bool {
+ if _, ok := err.(*NoAvailableProviderError); ok {
+ return true
+ }
+ if _, ok := err.(NoAvailableProviderError); ok {
+ return true
+ }
+ return false
+}
diff --git a/pkg/credentials/provider/file_provider.go b/pkg/credentials/provider/file_provider.go
index 231b6d3a..5e8eebde 100644
--- a/pkg/credentials/provider/file_provider.go
+++ b/pkg/credentials/provider/file_provider.go
@@ -43,6 +43,10 @@ func (f *FileProvider) Credentials(ctx context.Context) (*Credentials, error) {
return f.u.Credentials(ctx)
}
+func (f *FileProvider) Stop(ctx context.Context) {
+ f.u.Stop(ctx)
+}
+
func (f *FileProvider) getCredentials(ctx context.Context) (*Credentials, error) {
data, err := os.ReadFile(f.filepath)
if err != nil {
diff --git a/pkg/credentials/provider/file_provider_test.go b/pkg/credentials/provider/file_provider_test.go
new file mode 100644
index 00000000..bf1dad12
--- /dev/null
+++ b/pkg/credentials/provider/file_provider_test.go
@@ -0,0 +1,39 @@
+package provider
+
+import (
+ "context"
+ "os"
+ "path"
+ "testing"
+ "time"
+)
+
+func TestFileProvider_Credentials(t *testing.T) {
+ d, err := os.MkdirTemp("", "TestFileProvider_Credentials")
+ if err != nil {
+ t.Errorf("should not error: %+v", err)
+ return
+ }
+ fp := path.Join(d, "test.json")
+ os.WriteFile(fp, []byte("abc"), 0600)
+
+ f := NewFileProvider(fp, func(data []byte) (*Credentials, error) {
+ return &Credentials{
+ AccessKeyId: "ak_TestFileProvider_Credentials",
+ AccessKeySecret: "sk_TestFileProvider_Credentials",
+ SecurityToken: "",
+ Expiration: time.Time{},
+ }, nil
+ }, FileProviderOptions{})
+
+ cred, err := f.Credentials(context.TODO())
+ if err != nil {
+ t.Errorf("should not error: %+v", err)
+ return
+ }
+ if cred.AccessKeyId != "ak_TestFileProvider_Credentials" ||
+ cred.AccessKeySecret != "sk_TestFileProvider_Credentials" {
+ t.Errorf("unexpected case found: %+v", *cred)
+ return
+ }
+}
diff --git a/pkg/credentials/provider/function_provider_test.go b/pkg/credentials/provider/function_provider_test.go
new file mode 100644
index 00000000..f75bf4bf
--- /dev/null
+++ b/pkg/credentials/provider/function_provider_test.go
@@ -0,0 +1,29 @@
+package provider
+
+import (
+ "context"
+ "testing"
+ "time"
+)
+
+func TestFunctionProvider_Credentials(t *testing.T) {
+ f := NewFunctionProvider(func(ctx context.Context) (*Credentials, error) {
+ return &Credentials{
+ AccessKeyId: "ak_TestFunctionProvider_Credentials",
+ AccessKeySecret: "sk_TestFunctionProvider_Credentials",
+ SecurityToken: "",
+ Expiration: time.Time{},
+ }, nil
+ })
+
+ cred, err := f.Credentials(context.TODO())
+ if err != nil {
+ t.Errorf("should not error: %+v", err)
+ return
+ }
+ if cred.AccessKeyId != "ak_TestFunctionProvider_Credentials" ||
+ cred.AccessKeySecret != "sk_TestFunctionProvider_Credentials" {
+ t.Errorf("unexpected case found: %+v", *cred)
+ return
+ }
+}
diff --git a/pkg/credentials/provider/oidc_provider.go b/pkg/credentials/provider/oidc_provider.go
index 73800685..531cce0e 100644
--- a/pkg/credentials/provider/oidc_provider.go
+++ b/pkg/credentials/provider/oidc_provider.go
@@ -19,6 +19,8 @@ const (
defaultEnvRoleArn = "ALIBABA_CLOUD_ROLE_ARN"
defaultEnvOIDCProviderArn = "ALIBABA_CLOUD_OIDC_PROVIDER_ARN"
defaultEnvOIDCTokenFile = "ALIBABA_CLOUD_OIDC_TOKEN_FILE"
+
+ defaultExpiryWindowForAssumeRole = time.Minute * 10
)
var defaultSessionName = "default-session-name"
@@ -98,12 +100,16 @@ func (o *OIDCProvider) Credentials(ctx context.Context) (*Credentials, error) {
return o.u.Credentials(ctx)
}
+func (o *OIDCProvider) Stop(ctx context.Context) {
+ o.u.Stop(ctx)
+}
+
func (o *OIDCProvider) getCredentials(ctx context.Context) (*Credentials, error) {
roleArn := o.roleArn
oidcProviderArn := o.oidcProviderArn
tokenFile := o.oidcTokenFile
if roleArn == "" || oidcProviderArn == "" || tokenFile == "" {
- return nil, NewNotEnableError(errors.New("roleArn, oidcProviderArn or tokenFile is empty"))
+ return nil, NewNotEnableError(errors.New("roleArn, oidcProviderArn or oidcTokenFile is empty"))
}
tokenData, err := os.ReadFile(tokenFile)
@@ -235,7 +241,7 @@ func (o *OIDCProviderOptions) applyDefaults() {
o.SessionName = defaultSessionName
}
if o.ExpiryWindow == 0 {
- o.ExpiryWindow = defaultExpiryWindow
+ o.ExpiryWindow = defaultExpiryWindowForAssumeRole
}
if o.EnvRoleArn == "" {
o.EnvRoleArn = defaultEnvRoleArn
diff --git a/pkg/credentials/provider/oidc_provider_test.go b/pkg/credentials/provider/oidc_provider_test.go
new file mode 100644
index 00000000..99978ee8
--- /dev/null
+++ b/pkg/credentials/provider/oidc_provider_test.go
@@ -0,0 +1,44 @@
+package provider
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+ "os"
+ "testing"
+)
+
+func TestOIDCProvider_Credentials_success(t *testing.T) {
+ s := setupHttpTestServer(t, func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprint(w, `
+{
+ "Credentials": {
+ "AccessKeyId": "ak",
+ "AccessKeySecret": "sk",
+ "SecurityToken": "tt",
+ "Expiration": "2206-01-02T15:04:05Z"
+ }
+}
+`)
+ })
+
+ p := NewOIDCProvider(OIDCProviderOptions{
+ STSEndpoint: s.URL,
+ RoleArn: "role_arn",
+ OIDCProviderArn: "oidc_arn",
+ OIDCTokenFile: os.Args[0],
+ })
+
+ cred, err := p.Credentials(context.TODO())
+ if err != nil {
+ t.Log(err)
+ t.Errorf("should no error: %+v", err)
+ }
+
+ if cred.AccessKeyId != "ak" ||
+ cred.AccessKeySecret != "sk" ||
+ cred.SecurityToken != "tt" ||
+ cred.Expiration.IsZero() {
+ t.Errorf("got unexpected cred")
+ }
+}
diff --git a/pkg/credentials/provider/provider.go b/pkg/credentials/provider/provider.go
index 80651ad7..11b9289d 100644
--- a/pkg/credentials/provider/provider.go
+++ b/pkg/credentials/provider/provider.go
@@ -14,18 +14,11 @@ type CredentialsProvider interface {
Credentials(ctx context.Context) (*Credentials, error)
}
+type Stopper interface {
+ Stop(ctx context.Context)
+}
+
func init() {
name := path.Base(os.Args[0])
UserAgent = fmt.Sprintf("%s %s/%s ack-ram-tool/%s", name, runtime.GOOS, runtime.GOARCH, runtime.Version())
}
-
-type NotEnableError struct {
- err error
-}
-
-func NewNotEnableError(err error) *NotEnableError {
- return &NotEnableError{err: err}
-}
-func (e NotEnableError) Error() string {
- return fmt.Sprintf("this provider is not enabled: %s", e.err.Error())
-}
diff --git a/pkg/credentials/provider/rolearn_provider.go b/pkg/credentials/provider/rolearn_provider.go
index 7f607a03..0ae6ee64 100644
--- a/pkg/credentials/provider/rolearn_provider.go
+++ b/pkg/credentials/provider/rolearn_provider.go
@@ -69,6 +69,13 @@ func (r *RoleArnProvider) Credentials(ctx context.Context) (*Credentials, error)
return r.u.Credentials(ctx)
}
+func (r *RoleArnProvider) Stop(ctx context.Context) {
+ r.u.Stop(ctx)
+ if s, ok := r.cp.(Stopper); ok {
+ s.Stop(ctx)
+ }
+}
+
func (r *RoleArnProvider) getCredentials(ctx context.Context) (*Credentials, error) {
return r.assumeRole(ctx, r.roleArn)
}
@@ -200,7 +207,7 @@ func (o *RoleArnProviderOptions) applyDefaults() {
o.SessionName = defaultSessionName
}
if o.ExpiryWindow == 0 {
- o.ExpiryWindow = defaultExpiryWindow
+ o.ExpiryWindow = defaultExpiryWindowForAssumeRole
}
if o.Logger == nil {
o.Logger = defaultLog
diff --git a/pkg/credentials/provider/rolearn_provider_test.go b/pkg/credentials/provider/rolearn_provider_test.go
new file mode 100644
index 00000000..2f55263d
--- /dev/null
+++ b/pkg/credentials/provider/rolearn_provider_test.go
@@ -0,0 +1,110 @@
+package provider
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+)
+
+func TestRoleArnProvider_Credentials_success(t *testing.T) {
+ s := setupHttpTestServer(t, func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprint(w, `
+{
+ "Credentials": {
+ "AccessKeyId": "ak",
+ "AccessKeySecret": "sk",
+ "SecurityToken": "tt",
+ "Expiration": "2206-01-02T15:04:05Z"
+ }
+}
+`)
+ })
+
+ p := NewRoleArnProvider(
+ NewAccessKeyProvider("ak1", "sk1"),
+ "role_arn",
+ RoleArnProviderOptions{
+ STSEndpoint: s.URL,
+ },
+ )
+
+ cred, err := p.Credentials(context.TODO())
+ if err != nil {
+ t.Log(err)
+ t.Errorf("should no error: %+v", err)
+ }
+
+ if cred.AccessKeyId != "ak" ||
+ cred.AccessKeySecret != "sk" ||
+ cred.SecurityToken != "tt" ||
+ cred.Expiration.IsZero() {
+ t.Errorf("got unexpected cred")
+ }
+}
+
+func TestRoleArnProvider_Credentials_stop_with_no_stop_method_cp(t *testing.T) {
+ s := setupHttpTestServer(t, func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprint(w, `
+{
+ "Credentials": {
+ "AccessKeyId": "ak",
+ "AccessKeySecret": "sk",
+ "SecurityToken": "tt",
+ "Expiration": "2206-01-02T15:04:05Z"
+ }
+}
+`)
+ })
+
+ p := NewRoleArnProvider(
+ NewAccessKeyProvider("ak1", "sk1"),
+ "role_arn",
+ RoleArnProviderOptions{
+ STSEndpoint: s.URL,
+ Logger: TLogger{t},
+ },
+ )
+
+ p.Stop(context.TODO())
+ p.Stop(context.TODO())
+}
+
+func TestRoleArnProvider_Credentials_stop_with_stop_method_cp(t *testing.T) {
+ s := setupHttpTestServer(t, func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprint(w, `
+{
+ "Credentials": {
+ "AccessKeyId": "ak",
+ "AccessKeySecret": "sk",
+ "SecurityToken": "tt",
+ "Expiration": "2206-01-02T15:04:05Z"
+ }
+}
+`)
+ })
+
+ p := NewRoleArnProvider(
+ NewOIDCProvider(OIDCProviderOptions{
+ Logger: TLogger{t},
+ }),
+ "role_arn",
+ RoleArnProviderOptions{
+ STSEndpoint: s.URL,
+ Logger: TLogger{t},
+ },
+ )
+
+ p.Stop(context.TODO())
+ p.Stop(context.TODO())
+}
+
+func setupHttpTestServer(t *testing.T, handler func(w http.ResponseWriter, r *http.Request)) *httptest.Server {
+ mux := http.NewServeMux()
+ mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+ handler(w, r)
+ })
+ s := httptest.NewServer(mux)
+ return s
+}
diff --git a/pkg/credentials/provider/ststoken_provider_test.go b/pkg/credentials/provider/ststoken_provider_test.go
new file mode 100644
index 00000000..ee197a0d
--- /dev/null
+++ b/pkg/credentials/provider/ststoken_provider_test.go
@@ -0,0 +1,43 @@
+package provider
+
+import (
+ "context"
+ "testing"
+ "time"
+)
+
+func TestSTSTokenProvider_Credentials(t *testing.T) {
+ p := NewSTSTokenProvider("ak", "sk", "tt")
+ cred, err := p.Credentials(context.TODO())
+
+ if err != nil {
+ t.Log(err)
+ t.Errorf("should not return error: %+v", err)
+ }
+
+ if cred.AccessKeyId != "ak" ||
+ cred.AccessKeySecret != "sk" ||
+ cred.SecurityToken != "tt" ||
+ !cred.Expiration.IsZero() {
+ t.Error("cred value is not expected")
+ }
+}
+
+func TestSTSTokenProvider_SetExpiration(t *testing.T) {
+ p := NewSTSTokenProvider("ak", "sk", "tt")
+ tm := time.Now()
+ p.SetExpiration(tm)
+ cred, err := p.Credentials(context.TODO())
+
+ if err != nil {
+ t.Log(err)
+ t.Errorf("should not return error: %+v", err)
+ }
+
+ if cred.AccessKeyId != "ak" ||
+ cred.AccessKeySecret != "sk" ||
+ cred.SecurityToken != "tt" ||
+ cred.Expiration.IsZero() {
+ t.Error("cred value is not expected")
+ }
+}
diff --git a/pkg/credentials/provider/updater.go b/pkg/credentials/provider/updater.go
index 0a51285e..705f436d 100644
--- a/pkg/credentials/provider/updater.go
+++ b/pkg/credentials/provider/updater.go
@@ -25,6 +25,9 @@ type Updater struct {
Logger Logger
nowFunc func() time.Time
logPrefix string
+
+ doneCh chan struct{}
+ stopped bool
}
type UpdaterOptions struct {
@@ -45,6 +48,7 @@ func NewUpdater(getter getCredentialsFunc, opts UpdaterOptions) *Updater {
Logger: opts.Logger,
nowFunc: time.Now,
logPrefix: opts.LogPrefix,
+ doneCh: make(chan struct{}),
}
return u
}
@@ -57,17 +61,36 @@ func (u *Updater) Start(ctx context.Context) {
go u.startRefreshLoop(ctx)
}
+func (u *Updater) Stop(shutdownCtx context.Context) {
+ u.logger().Debug(fmt.Sprintf("%s start to stop...", u.logPrefix))
+
+ go func() {
+ u.lockForCred.Lock()
+ defer u.lockForCred.Unlock()
+ if u.stopped {
+ return
+ }
+ u.stopped = true
+ close(u.doneCh)
+ }()
+
+ select {
+ case <-shutdownCtx.Done():
+ case <-u.doneCh:
+ }
+}
+
func (u *Updater) startRefreshLoop(ctx context.Context) {
ticket := time.NewTicker(u.refreshPeriod)
defer ticket.Stop()
- u.refreshCredForLoop(ctx)
-
loop:
for {
select {
case <-ctx.Done():
break loop
+ case <-u.doneCh:
+ break loop
case <-ticket.C:
u.refreshCredForLoop(ctx)
}
@@ -101,7 +124,7 @@ func (u *Updater) refreshCredForLoop(ctx context.Context) {
if err == nil {
return
}
- if _, ok := err.(*NotEnableError); ok {
+ if IsNotEnableError(err) {
return
}
if i < maxRetry-1 {
@@ -113,7 +136,7 @@ func (u *Updater) refreshCredForLoop(ctx context.Context) {
func (u *Updater) refreshCred(ctx context.Context) error {
cred, err := u.getCredentials(ctx)
if err != nil {
- if _, ok := err.(*NotEnableError); ok {
+ if IsNotEnableError(err) {
return err
}
u.logger().Error(err, fmt.Sprintf("%s refresh credentials failed: %s", u.logPrefix, err))
@@ -151,8 +174,11 @@ func (u *Updater) Expired() bool {
func (u *Updater) expired(expiryDelta time.Duration) bool {
exp := u.expiration()
+ if expiryDelta > 0 {
+ exp = exp.Add(-expiryDelta)
+ }
- return exp.Add(-expiryDelta).Before(u.now())
+ return exp.Before(u.now())
}
func (u *Updater) expiration() time.Time {
diff --git a/pkg/credentials/provider/updater_test.go b/pkg/credentials/provider/updater_test.go
index 5802ddc7..59f0fb92 100644
--- a/pkg/credentials/provider/updater_test.go
+++ b/pkg/credentials/provider/updater_test.go
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
+ "sync/atomic"
"testing"
"time"
)
@@ -25,12 +26,12 @@ func (d TLogger) Error(err error, msg string) {
}
func TestUpdater_refreshCredForLoop_refresh(t *testing.T) {
- var callCount int
+ var callCount int32
fakeCred := Credentials{
Expiration: time.Now().Add(time.Minute),
}
u := NewUpdater(func(ctx context.Context) (*Credentials, error) {
- callCount++
+ atomic.AddInt32(&callCount, 1)
return &fakeCred, nil
}, UpdaterOptions{
ExpiryWindow: 0,
@@ -39,8 +40,10 @@ func TestUpdater_refreshCredForLoop_refresh(t *testing.T) {
})
u.refreshCredForLoop(context.TODO())
- if callCount != 1 {
- t.Errorf("callCount should be 1 but got %d", callCount)
+
+ cv := atomic.LoadInt32(&callCount)
+ if cv != 1 {
+ t.Errorf("callCount should be 1 but got %d", cv)
}
ret := u.Expired()
if ret {
@@ -48,8 +51,9 @@ func TestUpdater_refreshCredForLoop_refresh(t *testing.T) {
}
u.refreshCredForLoop(context.TODO())
- if callCount != 1 {
- t.Errorf("callCount should be 1 but got %d", callCount)
+ cv = atomic.LoadInt32(&callCount)
+ if cv != 1 {
+ t.Errorf("callCount should be 1 but got %d", cv)
}
u.nowFunc = func() time.Time {
@@ -62,8 +66,9 @@ func TestUpdater_refreshCredForLoop_refresh(t *testing.T) {
fakeCred.Expiration = time.Now().Add(time.Minute * 5)
u.refreshCredForLoop(context.TODO())
- if callCount != 2 {
- t.Errorf("callCount should be 2 but got %d", callCount)
+ cv = atomic.LoadInt32(&callCount)
+ if cv != 2 {
+ t.Errorf("callCount should be 2 but got %d", cv)
}
ret = u.Expired()
if ret {
@@ -72,10 +77,10 @@ func TestUpdater_refreshCredForLoop_refresh(t *testing.T) {
}
func TestUpdater_refreshCredForLoop_erorr(t *testing.T) {
- var callCount int
+ var callCount int32
u := NewUpdater(func(ctx context.Context) (*Credentials, error) {
- callCount++
+ atomic.AddInt32(&callCount, 1)
return nil, errors.New("error message")
}, UpdaterOptions{
ExpiryWindow: 0,
@@ -84,8 +89,9 @@ func TestUpdater_refreshCredForLoop_erorr(t *testing.T) {
})
u.refreshCredForLoop(context.TODO())
- if callCount != 5 {
- t.Errorf("callCount should be 5 but got %d", callCount)
+ cv := atomic.LoadInt32(&callCount)
+ if cv != 5 {
+ t.Errorf("callCount should be 5 but got %d", cv)
}
ret := u.Expired()
if !ret {
@@ -94,12 +100,12 @@ func TestUpdater_refreshCredForLoop_erorr(t *testing.T) {
}
func TestUpdater_Credentials_refresh(t *testing.T) {
- var callCount int
+ var callCount int32
fakeCred := Credentials{
Expiration: time.Now().Add(time.Minute),
}
u := NewUpdater(func(ctx context.Context) (*Credentials, error) {
- callCount++
+ atomic.AddInt32(&callCount, 1)
return &fakeCred, nil
}, UpdaterOptions{
ExpiryWindow: 0,
@@ -107,35 +113,126 @@ func TestUpdater_Credentials_refresh(t *testing.T) {
Logger: TLogger{t: t},
})
- u.Credentials(context.TODO())
- if callCount != 1 {
- t.Errorf("callCount should be 1 but got %d", callCount)
- }
- ret := u.Expired()
- if ret {
- t.Errorf("should not expired")
- }
+ t.Run("Credentials use cache", func(t *testing.T) {
+ u.Credentials(context.TODO())
+ cv := atomic.LoadInt32(&callCount)
+ if cv != 1 {
+ t.Errorf("callCount should be 1 but got %d", cv)
+ }
+ ret := u.Expired()
+ if ret {
+ t.Errorf("should not expired")
+ }
+
+ u.Credentials(context.TODO())
+ cv = atomic.LoadInt32(&callCount)
+ if cv != 1 {
+ t.Errorf("callCount should be 1 but got %d", cv)
+ }
+ })
- u.Credentials(context.TODO())
- if callCount != 1 {
- t.Errorf("callCount should be 1 but got %d", callCount)
- }
+ t.Run("Credentials expired", func(t *testing.T) {
+ u.nowFunc = func() time.Time {
+ return time.Now().Add(time.Minute * 2)
+ }
+ ret := u.Expired()
+ if !ret {
+ t.Errorf("should expired")
+ }
+ })
- u.nowFunc = func() time.Time {
- return time.Now().Add(time.Minute)
- }
- ret = u.Expired()
- if !ret {
- t.Errorf("should expired")
- }
+ t.Run("not expire, should not refresh", func(t *testing.T) {
+ fakeCred.Expiration = time.Now().Add(time.Minute * 5)
+ u.Credentials(context.TODO())
+ cv := atomic.LoadInt32(&callCount)
+ if cv != 2 {
+ t.Errorf("callCount should be 2 but got %d", cv)
+ }
+ ret := u.Expired()
+ if ret {
+ t.Errorf("should not expired")
+ }
+ })
+}
- fakeCred.Expiration = time.Now().Add(time.Minute * 5)
- u.Credentials(context.TODO())
- if callCount != 2 {
- t.Errorf("callCount should be 2 but got %d", callCount)
+func TestUpdater_expired(t *testing.T) {
+ u := &Updater{}
+ u.setCred(&Credentials{Expiration: time.Now().Add(time.Minute)})
+
+ t.Run("expiryDelta=0", func(t *testing.T) {
+ ret := u.expired(0)
+ if ret {
+ t.Errorf("should be false")
+ }
+ })
+
+ t.Run("expiryDelta > 0", func(t *testing.T) {
+ ret := u.expired(time.Minute * 5)
+ if !ret {
+ t.Errorf("should be true")
+ }
+ })
+}
+
+func TestUpdater_stop(t *testing.T) {
+ var callCount int32
+ fakeCred := Credentials{
+ Expiration: time.Now().Add(-time.Minute),
}
- ret = u.Expired()
- if ret {
- t.Errorf("should not expired")
+ u := NewUpdater(func(ctx context.Context) (*Credentials, error) {
+ atomic.AddInt32(&callCount, 1)
+ return &fakeCred, nil
+ }, UpdaterOptions{
+ ExpiryWindow: 0,
+ RefreshPeriod: time.Millisecond * 100,
+ Logger: TLogger{t: t},
+ })
+
+ u.Start(context.TODO())
+
+ t.Run("test-refresh", func(t *testing.T) {
+ time.Sleep(time.Second)
+ cv := atomic.LoadInt32(&callCount)
+ if cv < 1 {
+ t.Errorf("callCount should be >1 but got %d", cv)
+ }
+ })
+
+ t.Run("test-stop", func(t *testing.T) {
+ u.Stop(context.TODO())
+ time.Sleep(time.Second)
+
+ curr := atomic.LoadInt32(&callCount)
+ time.Sleep(time.Second)
+
+ cv := atomic.LoadInt32(&callCount)
+ if cv != curr {
+ t.Errorf("callCount should be %d but got %d", curr, cv)
+ }
+ })
+
+ t.Run("test-stop-multiple-times", func(t *testing.T) {
+ u.Stop(context.TODO())
+ u.Stop(context.TODO())
+ u.Stop(context.TODO())
+ })
+}
+
+func TestUpdater_stop_no_start(t *testing.T) {
+ var callCount int32
+ fakeCred := Credentials{
+ Expiration: time.Now().Add(-time.Minute),
}
+ u := NewUpdater(func(ctx context.Context) (*Credentials, error) {
+ atomic.AddInt32(&callCount, 1)
+ return &fakeCred, nil
+ }, UpdaterOptions{
+ ExpiryWindow: 0,
+ RefreshPeriod: 0,
+ Logger: TLogger{t: t},
+ })
+
+ u.Stop(context.TODO())
+ u.Stop(context.TODO())
+ u.Stop(context.TODO())
}