From 7796c0e52ee68366d34685a7b63b224eb670f044 Mon Sep 17 00:00:00 2001 From: Shraddha Bang Date: Mon, 16 Sep 2024 16:23:56 -0700 Subject: [PATCH] Adding pending tasks --- go.mod | 21 +- go.sum | 42 +-- pkg/aws/cloud.go | 43 ++- pkg/aws/metrics/collector.go | 230 +++++++------ pkg/aws/metrics/collector_test.go | 170 +++------- pkg/aws/services/elbv2.go | 47 +-- pkg/aws/throttle/condition.go | 58 ++-- pkg/aws/throttle/condition_test.go | 246 +++----------- pkg/aws/throttle/config_test.go | 513 ++++++++++++++-------------- pkg/aws/throttle/throttler.go | 148 +++++---- pkg/aws/throttle/throttler_test.go | 515 +++++++++++++---------------- 11 files changed, 896 insertions(+), 1137 deletions(-) diff --git a/go.mod b/go.mod index e07b1a8cb..85cfc0a05 100644 --- a/go.mod +++ b/go.mod @@ -3,17 +3,19 @@ module sigs.k8s.io/aws-load-balancer-controller go 1.22.4 require ( - github.com/aws/aws-sdk-go-v2 v1.30.3 + github.com/aws/aws-sdk-go-v2 v1.30.5 github.com/aws/aws-sdk-go-v2/config v1.27.27 github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 github.com/aws/aws-sdk-go-v2/service/acm v1.28.4 + github.com/aws/aws-sdk-go-v2/service/appmesh v1.27.7 github.com/aws/aws-sdk-go-v2/service/ec2 v1.173.0 - github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2 v1.34.0 + github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2 v1.36.0 github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi v1.23.3 + github.com/aws/aws-sdk-go-v2/service/servicediscovery v1.31.7 github.com/aws/aws-sdk-go-v2/service/shield v1.27.3 github.com/aws/aws-sdk-go-v2/service/wafregional v1.23.3 github.com/aws/aws-sdk-go-v2/service/wafv2 v1.51.4 - github.com/aws/smithy-go v1.20.3 + github.com/aws/smithy-go v1.20.4 github.com/evanphx/json-patch v5.7.0+incompatible github.com/gavv/httpexpect/v2 v2.9.0 github.com/go-logr/logr v1.4.1 @@ -22,7 +24,7 @@ require ( github.com/onsi/ginkgo/v2 v2.17.1 github.com/onsi/gomega v1.32.0 github.com/pkg/errors v0.9.1 - github.com/prometheus/client_golang v1.16.0 + github.com/prometheus/client_golang v1.18.0 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.8.4 go.uber.org/zap v1.26.0 @@ -52,8 +54,8 @@ require ( github.com/andybalholm/brotli v1.0.4 // indirect github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.17.27 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.17 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.17 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 // indirect @@ -123,7 +125,6 @@ require ( github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.17 // indirect github.com/mattn/go-runewidth v0.0.9 // indirect - github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect github.com/mitchellh/copystructure v1.2.0 // indirect github.com/mitchellh/go-wordwrap v1.0.1 // indirect github.com/mitchellh/reflectwalk v1.0.2 // indirect @@ -140,8 +141,8 @@ require ( github.com/opencontainers/image-spec v1.1.0-rc5 // indirect github.com/peterbourgon/diskv v2.0.1+incompatible // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/prometheus/client_model v0.4.0 // indirect - github.com/prometheus/common v0.44.0 // indirect + github.com/prometheus/client_model v0.5.0 // indirect + github.com/prometheus/common v0.48.0 // indirect github.com/prometheus/procfs v0.12.0 // indirect github.com/rubenv/sql-migrate v1.5.2 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect @@ -169,7 +170,7 @@ require ( golang.org/x/crypto v0.21.0 // indirect golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e // indirect golang.org/x/net v0.23.0 // indirect - golang.org/x/oauth2 v0.12.0 // indirect + golang.org/x/oauth2 v0.16.0 // indirect golang.org/x/sync v0.6.0 // indirect golang.org/x/sys v0.18.0 // indirect golang.org/x/term v0.18.0 // indirect diff --git a/go.sum b/go.sum index 5de579e33..c417c0dad 100644 --- a/go.sum +++ b/go.sum @@ -36,32 +36,36 @@ github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPd github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535 h1:4daAzAu0S6Vi7/lbWECcX0j45yZReDZ56BQsrVBOEEY= github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535/go.mod h1:oGkLhpf+kjZl6xBf758TQhh5XrAeiJv/7FRz/2spLIg= -github.com/aws/aws-sdk-go-v2 v1.30.3 h1:jUeBtG0Ih+ZIFH0F4UkmL9w3cSpaMv9tYYDbzILP8dY= -github.com/aws/aws-sdk-go-v2 v1.30.3/go.mod h1:nIQjQVp5sfpQcTc9mPSr1B0PaWK5ByX9MOoDadSN4lc= +github.com/aws/aws-sdk-go-v2 v1.30.5 h1:mWSRTwQAb0aLE17dSzztCVJWI9+cRMgqebndjwDyK0g= +github.com/aws/aws-sdk-go-v2 v1.30.5/go.mod h1:CT+ZPWXbYrci8chcARI3OmI/qgd+f6WtuLOoaIA8PR0= github.com/aws/aws-sdk-go-v2/config v1.27.27 h1:HdqgGt1OAP0HkEDDShEl0oSYa9ZZBSOmKpdpsDMdO90= github.com/aws/aws-sdk-go-v2/config v1.27.27/go.mod h1:MVYamCg76dFNINkZFu4n4RjDixhVr51HLj4ErWzrVwg= github.com/aws/aws-sdk-go-v2/credentials v1.17.27 h1:2raNba6gr2IfA0eqqiP2XiQ0UVOpGPgDSi0I9iAP+UI= github.com/aws/aws-sdk-go-v2/credentials v1.17.27/go.mod h1:gniiwbGahQByxan6YjQUMcW4Aov6bLC3m+evgcoN4r4= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 h1:KreluoV8FZDEtI6Co2xuNk/UqI9iwMrOx/87PBNIKqw= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11/go.mod h1:SeSUYBLsMYFoRvHE0Tjvn7kbxaUhl75CJi1sbfhMxkU= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 h1:SoNJ4RlFEQEbtDcCEt+QG56MY4fm4W8rYirAmq+/DdU= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15/go.mod h1:U9ke74k1n2bf+RIgoX1SXFed1HLs51OgUSs+Ph0KJP8= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15 h1:C6WHdGnTDIYETAm5iErQUiVNsclNx9qbJVPIt03B6bI= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15/go.mod h1:ZQLZqhcu+JhSrA9/NXRm8SkDvsycE+JkV3WGY41e+IM= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.17 h1:pI7Bzt0BJtYA0N/JEC6B8fJ4RBrEMi1LBrkMdFYNSnQ= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.17/go.mod h1:Dh5zzJYMtxfIjYW+/evjQ8uj2OyR/ve2KROHGHlSFqE= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.17 h1:Mqr/V5gvrhA2gvgnF42Zh5iMiQNcOYthFYwCyrnuWlc= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.17/go.mod h1:aLJpZlCmjE+V+KtN1q1uyZkfnUWpQGpbsn89XPKyzfU= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY= github.com/aws/aws-sdk-go-v2/service/acm v1.28.4 h1:wiW1Y6/1lysA0eJZRq0I53YYKuV9MNAzL15z2eZRlEE= github.com/aws/aws-sdk-go-v2/service/acm v1.28.4/go.mod h1:bzjymHHRhexkSMIvUHMpKydo9U82bmqQ5ru0IzYM8m8= +github.com/aws/aws-sdk-go-v2/service/appmesh v1.27.7 h1:q44a6kysAfej9zZwRnraOg9sBVIKhxKjPbqYs44Vpdk= +github.com/aws/aws-sdk-go-v2/service/appmesh v1.27.7/go.mod h1:ZYSmrgAMp0rTCHH+SGsoxZo+PPbgsDqBzewTp3tSJ60= github.com/aws/aws-sdk-go-v2/service/ec2 v1.173.0 h1:ta62lid9JkIpKZtZZXSj6rP2AqY5x1qYGq53ffxqD9Q= github.com/aws/aws-sdk-go-v2/service/ec2 v1.173.0/go.mod h1:o6QDjdVKpP5EF0dp/VlvqckzuSDATr1rLdHt3A5m0YY= -github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2 v1.34.0 h1:8rDRtPOu3ax8jEctw7G926JQlnFdhZZA4KJzQ+4ks3Q= -github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2 v1.34.0/go.mod h1:L5bVuO4PeXuDuMYZfL3IW69E6mz6PDCYpp6IKDlcLMA= +github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2 v1.36.0 h1:3t8g6wmPA9hr69qzDraI1umO2An7jKNe75dBsxbI30E= +github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2 v1.36.0/go.mod h1:jk+iid9R4MN7UVDwSTK/ZDDO8WNhxnO2WVzfYOMLh+4= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 h1:dT3MqvGhSoaIhRseqw2I0yH81l7wiR2vjs57O51EAm8= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3/go.mod h1:GlAeCkHwugxdHaueRr4nhPuY+WW+gR8UjlcqzPr1SPI= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 h1:HGErhhrxZlQ044RiM+WdoZxp0p+EGM62y3L6pwA4olE= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17/go.mod h1:RkZEx4l0EHYDJpWppMJ3nD9wZJAa8/0lq9aVC+r2UII= github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi v1.23.3 h1:ByynKMsGZGmpUpnQ99y+lS7VxZrNt3mdagCnHd011Kk= github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi v1.23.3/go.mod h1:ZR4h87npHPuVQ2SEeoWMe+CO/HcS9g2iYMLnT5HawW8= +github.com/aws/aws-sdk-go-v2/service/servicediscovery v1.31.7 h1:mHdnEFOQ0JVjsbjHGqkuE0pmEpnk/aWz8YxyyB4e2+E= +github.com/aws/aws-sdk-go-v2/service/servicediscovery v1.31.7/go.mod h1:JsD+G3R0ZMWqjt7VDggNsc5SFl4hw+Sk8KQaRN1sltI= github.com/aws/aws-sdk-go-v2/service/shield v1.27.3 h1:SfjI6FuphzspGPvcRD8hjMD6wLUAE6vtJLGrui19j2s= github.com/aws/aws-sdk-go-v2/service/shield v1.27.3/go.mod h1:JpxjPa91y1hRb3G8xxzhOQFcK/r90it41jA/hD0q+Gg= github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 h1:BXx0ZIxvrJdSgSvKTZ+yRBeSqqgPM89VPlulEcl37tM= @@ -74,8 +78,8 @@ github.com/aws/aws-sdk-go-v2/service/wafregional v1.23.3 h1:7dr6En0/6KRFoz8VmnYk github.com/aws/aws-sdk-go-v2/service/wafregional v1.23.3/go.mod h1:24TtlRsv4LKAE3VnRJQhpatr8cpX0yj8NSzg8/lxOCw= github.com/aws/aws-sdk-go-v2/service/wafv2 v1.51.4 h1:1khBA5uryBRJoCb4G2iR5RT06BkfPEjjDCHAiRb8P3Q= github.com/aws/aws-sdk-go-v2/service/wafv2 v1.51.4/go.mod h1:QpFImaPGKNwa+MiZ+oo6LbV1PVQBapc0CnrAMRScoxM= -github.com/aws/smithy-go v1.20.3 h1:ryHwveWzPV5BIof6fyDvor6V3iUL7nTfiTKXHiW05nE= -github.com/aws/smithy-go v1.20.3/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= +github.com/aws/smithy-go v1.20.4 h1:2HK1zBdPgRbjFOHlfeQZfpC4r72MOb9bZkiFwggKO+4= +github.com/aws/smithy-go v1.20.4/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -341,8 +345,6 @@ github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI= github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= -github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= github.com/miekg/dns v1.1.25 h1:dFwPR6SfLtrSwgDcIq2bcU/gVutB4sNApq2HBdqcakg= github.com/miekg/dns v1.1.25/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso= github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw= @@ -405,17 +407,17 @@ github.com/poy/onpar v1.1.2/go.mod h1:6X8FLNoxyr9kkmnlqpK6LSoiOtrO6MICtWwEuWkLjz github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.1.0/go.mod h1:I1FGZT9+L76gKKOs5djB6ezCbFQP1xR9D75/vuwEF3g= -github.com/prometheus/client_golang v1.16.0 h1:yk/hx9hDbrGHovbci4BY+pRMfSuuat626eFsHb7tmT8= -github.com/prometheus/client_golang v1.16.0/go.mod h1:Zsulrv/L9oM40tJ7T815tM89lFEugiJ9HzIqaAx4LKc= +github.com/prometheus/client_golang v1.18.0 h1:HzFfmkOzH5Q8L8G+kSJKUx5dtG87sewO+FoDDqP5Tbk= +github.com/prometheus/client_golang v1.18.0/go.mod h1:T+GXkCk5wSJyOqMIzVgvvjFDlkOQntgjkJWKrN5txjA= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.4.0 h1:5lQXD3cAg1OXBf4Wq03gTrXHeaV0TQvGfUooCfx1yqY= -github.com/prometheus/client_model v0.4.0/go.mod h1:oMQmHW1/JoDwqLtg57MGgP/Fb1CJEYF2imWWhWtMkYU= +github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw= +github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.6.0/go.mod h1:eBmuwkDJBwy6iBfxCBob6t6dR6ENT/y+J+Zk0j9GMYc= -github.com/prometheus/common v0.44.0 h1:+5BrQJwiBB9xsMygAB3TNvpQKOwlkc25LbISbrdOOfY= -github.com/prometheus/common v0.44.0/go.mod h1:ofAIvZbQ1e/nugmZGz4/qCb9Ap1VoSTIO7x0VV9VvuY= +github.com/prometheus/common v0.48.0 h1:QO8U2CdOzSn1BBsmXJXduaaW+dY/5QLjfB8svtSzKKE= +github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.0.3/go.mod h1:4A/X28fw3Fc593LaREMrKMqOKvUAntwMDaekg4FpcdQ= @@ -562,8 +564,8 @@ golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.12.0 h1:smVPGxink+n1ZI5pkQa8y6fZT0RW0MgCO5bFpepy4B4= -golang.org/x/oauth2 v0.12.0/go.mod h1:A74bZ3aGXgCY0qaIC9Ahg6Lglin4AMAco8cIv9baba4= +golang.org/x/oauth2 v0.16.0 h1:aDkGMBSYxElaoP81NpoUoz2oo2R2wHdZpGToUxfyQrQ= +golang.org/x/oauth2 v0.16.0/go.mod h1:hqZ+0LWXsiVoZpeld6jVt06P3adbS2Uu911W1SsJv2o= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= diff --git a/pkg/aws/cloud.go b/pkg/aws/cloud.go index 99fa4f5e5..db335e5f1 100644 --- a/pkg/aws/cloud.go +++ b/pkg/aws/cloud.go @@ -4,11 +4,15 @@ import ( "context" "fmt" awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware" + "github.com/aws/aws-sdk-go-v2/aws/ratelimit" + "github.com/aws/aws-sdk-go-v2/aws/retry" "github.com/aws/aws-sdk-go-v2/config" ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" smithymiddleware "github.com/aws/smithy-go/middleware" "net" "os" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/metrics" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/throttle" "sigs.k8s.io/aws-load-balancer-controller/pkg/version" "strings" @@ -97,26 +101,37 @@ func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger l } awsConfig, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion(cfg.Region), - config.WithRetryMaxAttempts(cfg.MaxRetries), + config.WithRetryer(func() aws.Retryer { + return retry.NewStandard(func(o *retry.StandardOptions) { + o.RateLimiter = ratelimit.None + o.MaxAttempts = cfg.MaxRetries + }) + }), config.WithEC2IMDSEndpointMode(ec2IMDSEndpointMode), config.WithAPIOptions([]func(stack *smithymiddleware.Stack) error{ awsmiddleware.AddUserAgentKeyValue(userAgent, version.GitVersion), }), ) - //TODO: ADD metric collection and throttle configuration later - //if cfg.ThrottleConfig != nil { - // throttler := throttle.NewThrottler(cfg.ThrottleConfig) - // throttler.InjectHandlers(&sess.Handlers) - //} - - //if metricsRegisterer != nil { - // metricsCollector, err := metrics.NewCollector(metricsRegisterer) - // if err != nil { - // return nil, errors.Wrapf(err, "failed to initialize sdk metrics collector") - // } - // awsConfig.APIOptions = append(awsConfig.APIOptions, metricsCollector.CollectAPICallMetricMiddleware()) - //} + if cfg.ThrottleConfig != nil { + throttler := throttle.NewThrottler(cfg.ThrottleConfig) + awsConfig.APIOptions = append(awsConfig.APIOptions, func(stack *smithymiddleware.Stack) error { + return throttle.WithSDKRequestThrottleMiddleware(throttler)(stack) + }) + } + + if metricsRegisterer != nil { + metricsCollector, err := metrics.NewCollector(metricsRegisterer) + if err != nil { + return nil, errors.Wrapf(err, "failed to initialize sdk metrics collector") + } + awsConfig.APIOptions = append(awsConfig.APIOptions, func(stack *smithymiddleware.Stack) error { + return metrics.WithSDKCallMetricCollector(metricsCollector)(stack) + }) + awsConfig.APIOptions = append(awsConfig.APIOptions, func(stack *smithymiddleware.Stack) error { + return metrics.WithSDKRequestMetricCollector(metricsCollector)(stack) + }) + } ec2Service := services.NewEC2(awsConfig) diff --git a/pkg/aws/metrics/collector.go b/pkg/aws/metrics/collector.go index 8cb552461..aaf2e9492 100644 --- a/pkg/aws/metrics/collector.go +++ b/pkg/aws/metrics/collector.go @@ -1,12 +1,21 @@ package metrics import ( + "context" + awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware" + "github.com/aws/aws-sdk-go-v2/aws/retry" + "github.com/aws/smithy-go" + smithymiddleware "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" + "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" + "strconv" + "time" ) const ( - sdkHandlerCollectAPICallMetric = "collectAPICallMetric" - sdkHandlerCollectAPIRequestMetric = "collectAPIRequestMetric" + sdkMiddlewareCollectAPICallMetric = "collectAPICallMetric" + sdkMiddlewareCollectAPIRequestMetric = "collectAPIRequestMetric" ) type collector struct { @@ -23,105 +32,120 @@ func NewCollector(registerer prometheus.Registerer) (*collector, error) { }, nil } -// TODO : WIP Migrate metric collection -//func (c *collector) InjectHandlers(cfg aws.Config) { -// handlers.CompleteAttempt.PushFrontNamed(request.NamedHandler{ -// Name: sdkHandlerCollectAPIRequestMetric, -// Fn: c.collectAPIRequestMetric, -// }) -// handlers.Complete.PushFrontNamed(request.NamedHandler{ -// Name: sdkHandlerCollectAPICallMetric, -// Fn: c.collectAPICallMetric, -// }) -//} +/* +WithSDKCallMetricCollector is a middleware for the AWS SDK Go v2 that collects and reports metrics on API calls. +The call metrics are collected after the call is completed +*/ +func WithSDKCallMetricCollector(c *collector) func(stack *smithymiddleware.Stack) error { + return func(stack *smithymiddleware.Stack) error { + return stack.Initialize.Add(smithymiddleware.InitializeMiddlewareFunc(sdkMiddlewareCollectAPICallMetric, func( + ctx context.Context, input smithymiddleware.InitializeInput, next smithymiddleware.InitializeHandler, + ) ( + output smithymiddleware.InitializeOutput, metadata smithymiddleware.Metadata, err error, + ) { + start := time.Now() + out, metadata, err := next.HandleInitialize(ctx, input) + resp, ok := awsmiddleware.GetRawResponse(metadata).(*smithyhttp.Response) + if !ok { + // No raw response to wrap with. + return out, metadata, err + } + service := awsmiddleware.GetServiceID(ctx) + operation := operationForRequest(ctx) + statusCode := strconv.Itoa(resp.StatusCode) + errorCode := errorCodeForRequest(err) + retryCount := getRetryMetricsForRequest(metadata) + duration := time.Since(start) + c.instruments.apiCallsTotal.With(map[string]string{ + labelService: service, + labelOperation: operation, + labelStatusCode: statusCode, + labelErrorCode: errorCode, + }).Inc() + c.instruments.apiCallDurationSeconds.With(map[string]string{ + labelService: service, + labelOperation: operation, + }).Observe(duration.Seconds()) + c.instruments.apiCallRetries.With(map[string]string{ + labelService: service, + labelOperation: operation, + }).Observe(retryCount) + return out, metadata, err + }), smithymiddleware.After) + } +} + +/* +WithSDKRequestMetricCollector is a middleware for the AWS SDK Go v2 that collects and reports metrics on API requests. +The request metrics are collected after each retry attempts +*/ +func WithSDKRequestMetricCollector(c *collector) func(stack *smithymiddleware.Stack) error { + return func(stack *smithymiddleware.Stack) error { + return stack.Finalize.Add(smithymiddleware.FinalizeMiddlewareFunc(sdkMiddlewareCollectAPIRequestMetric, func( + ctx context.Context, input smithymiddleware.FinalizeInput, next smithymiddleware.FinalizeHandler, + ) ( + output smithymiddleware.FinalizeOutput, metadata smithymiddleware.Metadata, err error, + ) { + start := time.Now() + out, metadata, err := next.HandleFinalize(ctx, input) + resp, ok := awsmiddleware.GetRawResponse(metadata).(*smithyhttp.Response) + if !ok { + // No raw response to wrap with. + return out, metadata, err + } + service := awsmiddleware.GetServiceID(ctx) + operation := operationForRequest(ctx) + statusCode := strconv.Itoa(resp.StatusCode) + errorCode := errorCodeForRequest(err) + c.instruments.apiRequestsTotal.With(map[string]string{ + labelService: service, + labelOperation: operation, + labelStatusCode: statusCode, + labelErrorCode: errorCode, + }).Inc() + + requestDuration, ok := awsmiddleware.GetResponseAt(metadata) + if ok { + c.instruments.apiRequestDurationSecond.With(map[string]string{ + labelService: service, + labelOperation: operation, + }).Observe(requestDuration.Sub(start).Seconds()) + } + return out, metadata, err + }), smithymiddleware.After) + } +} -//func (c *collector) CollectAPICallMetricMiddleware() func(*smithymiddleware.Stack) error { -// return func(stack *smithymiddleware.Stack) error { -// return stack.Finalize.Add(smithymiddleware.FinalizeMiddlewareFunc("CollectAPICallMetricMiddleware", func(ctx context.Context, input smithymiddleware.FinalizeInput, next smithymiddleware.FinalizeHandler) (output smithymiddleware.FinalizeOutput, metadata smithymiddleware.Metadata, err error) { -// start := time.Now() -// output, metadata, err = next.HandleFinalize(ctx, input) -// service := awsmiddleware.GetServiceID(ctx) -// operation := awsmiddleware.GetOperationName(ctx) -// response, ok := output.Result.(*smithyhttp.Response) -// if !ok { -// return ok -// } -// statusCode := response.StatusCode -// errorCode := response -// duration := time.Since(r.Time) -// -// }), smithymiddleware.Before) -// } -// } -//} -// -//func (c *collector) collectAPIRequestMetric(r *request.Request) { -// service := r.ClientInfo.ServiceID -// operation := r.Operation.Name -// statusCode := statusCodeForRequest(r) -// errorCode := errorCodeForRequest(r) -// duration := time.Since(r.AttemptTime) -// -// c.instruments.apiRequestsTotal.With(map[string]string{ -// labelService: service, -// labelOperation: operation, -// labelStatusCode: statusCode, -// labelErrorCode: errorCode, -// }).Inc() -// c.instruments.apiRequestDurationSecond.With(map[string]string{ -// labelService: service, -// labelOperation: operation, -// }).Observe(duration.Seconds()) -//} -// -//func (c *collector) collectAPICallMetric(r *request.Request) { -// service := r.ClientInfo.ServiceID -// operation := r.Operation.Name -// statusCode := statusCodeForRequest(r) -// errorCode := errorCodeForRequest(r) -// duration := time.Since(r.Time) -// -// c.instruments.apiCallsTotal.With(map[string]string{ -// labelService: service, -// labelOperation: operation, -// labelStatusCode: statusCode, -// labelErrorCode: errorCode, -// }).Inc() -// c.instruments.apiCallDurationSeconds.With(map[string]string{ -// labelService: service, -// labelOperation: operation, -// }).Observe(duration.Seconds()) -// c.instruments.apiCallRetries.With(map[string]string{ -// labelService: service, -// labelOperation: operation, -// }).Observe(float64(r.RetryCount)) -//} -// -//// statusCodeForRequest returns the http status code for request. -//// if there is no http response, returns "0". -//func statusCodeForRequest(r *request.Request) string { -// if r.HTTPResponse != nil { -// return strconv.Itoa(r.HTTPResponse.StatusCode) -// } -// return "0" -//} -// -//// errorCodeForRequest returns the error code for request. -//// if no error happened, returns "". -//func errorCodeForRequest(r *request.Request) string { -// if r.Error != nil { -// if awserr, ok := r.Error.(awserr.Error); ok { -// return awserr.Code() -// } -// return "internal" -// } -// return "" -//} -// -//// operationForRequest returns the operation for request. -//func operationForRequest(r *request.Request) string { -// if r.Operation != nil { -// return r.Operation.Name -// } -// return "?" -//} +func getRetryMetricsForRequest(metadata smithymiddleware.Metadata) float64 { + retries := float64(0) + attemptResults, ok := retry.GetAttemptResults(metadata) + if ok { + for _, result := range attemptResults.Results { + if result.Retried { + retries++ + } + } + } + return retries +} + +// errorCodeForRequest returns the error code for response. +func errorCodeForRequest(err error) string { + errCode := "" + if err == nil { + return errCode + } + var apiErr smithy.APIError + if errors.As(err, &apiErr) { + return apiErr.ErrorCode() + } + return "internal" +} + +// operationForRequest returns the operation for request. +func operationForRequest(ctx context.Context) string { + if awsmiddleware.GetOperationName(ctx) != "" { + return awsmiddleware.GetOperationName(ctx) + } + return "?" +} diff --git a/pkg/aws/metrics/collector_test.go b/pkg/aws/metrics/collector_test.go index e25ced51e..db2579381 100644 --- a/pkg/aws/metrics/collector_test.go +++ b/pkg/aws/metrics/collector_test.go @@ -1,126 +1,48 @@ package metrics -// TODO: WIP Migrate metric collection -//import ( -// "errors" -// "github.com/stretchr/testify/assert" -// "net/http" -// "testing" -//) -// -//func Test_statusCodeForRequest(t *testing.T) { -// type args struct { -// r *request.Request -// } -// tests := []struct { -// name string -// args args -// want string -// }{ -// { -// name: "requests without http response", -// args: args{ -// r: &request.Request{}, -// }, -// want: "0", -// }, -// { -// name: "requests with http response", -// args: args{ -// r: &request.Request{ -// HTTPResponse: &http.Response{ -// StatusCode: 200, -// }, -// }, -// }, -// want: "200", -// }, -// } -// for _, tt := range tests { -// t.Run(tt.name, func(t *testing.T) { -// got := statusCodeForRequest(tt.args.r) -// assert.Equal(t, tt.want, got) -// }) -// } -//} -// -//func Test_errorCodeForRequest(t *testing.T) { -// type args struct { -// r *request.Request -// } -// tests := []struct { -// name string -// args args -// want string -// }{ -// { -// name: "requests without error", -// args: args{ -// r: &request.Request{}, -// }, -// want: "", -// }, -// { -// name: "requests with internal error", -// args: args{ -// r: &request.Request{ -// Error: errors.New("oops, some internal error"), -// }, -// }, -// want: "internal", -// }, -// { -// name: "requests with aws error", -// args: args{ -// r: &request.Request{ -// Error: &smithy.GenericAPIError{Code: "NotFoundException", Message: ""}, -// }, -// }, -// want: "NotFoundException", -// }, -// }, -// for _, tt := range tests{ -// t.Run(tt.name, func (t *testing.T){ -// got := errorCodeForRequest(tt.args.r) -// assert.Equal(t, tt.want, got) -// }) -// } -// } -// -// func -// Test_operationForRequest(t * testing.T) -// { -// type args struct { -// r *request.Request -// } -// tests := []struct { -// name string -// args args -// want string -// }{ -// { -// name: "requests without operation", -// args: args{ -// r: &request.Request{}, -// }, -// want: "?", -// }, -// { -// name: "requests with operation", -// args: args{ -// r: &request.Request{ -// Operation: &request.Operation{ -// Name: "DescribeMesh", -// }, -// }, -// }, -// want: "DescribeMesh", -// }, -// } -// for _, tt := range tests { -// t.Run(tt.name, func(t *testing.T) { -// got := operationForRequest(tt.args.r) -// assert.Equal(t, tt.want, got) -// }) -// } -// } +import ( + "errors" + "github.com/aws/smithy-go" + "github.com/stretchr/testify/assert" + + "testing" +) + +func Test_errorCodeForRequest(t *testing.T) { + type args struct { + err error + } + tests := []struct { + name string + args args + want string + }{ + { + name: "requests without error", + args: args{ + err: nil, + }, + want: "", + }, + { + name: "requests with internal error", + args: args{ + err: errors.New("oops, some internal error"), + }, + want: "internal", + }, + { + name: "requests with aws error", + args: args{ + err: &smithy.GenericAPIError{Code: "NotFoundException", Message: ""}, + }, + want: "NotFoundException", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := errorCodeForRequest(tt.args.err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/aws/services/elbv2.go b/pkg/aws/services/elbv2.go index 2532494cb..22aa99ccc 100644 --- a/pkg/aws/services/elbv2.go +++ b/pkg/aws/services/elbv2.go @@ -18,11 +18,9 @@ type ELBV2 interface { // wrapper to DescribeListenersPagesWithContext API, which aggregates paged results into list. DescribeListenersAsList(ctx context.Context, input *elasticloadbalancingv2.DescribeListenersInput) ([]types.Listener, error) - // TODO : Implement these when the API paginator is available // wrapper to DescribeListenerCertificatesWithContext API, which aggregates paged results into list. DescribeListenerCertificatesAsList(ctx context.Context, input *elasticloadbalancingv2.DescribeListenerCertificatesInput) ([]types.Certificate, error) - // TODO : Implement these when the API paginator is available // wrapper to DescribeRulesWithContext API, which aggregates paged results into list. DescribeRulesAsList(ctx context.Context, input *elasticloadbalancingv2.DescribeRulesInput) ([]types.Rule, error) @@ -61,7 +59,6 @@ type ELBV2 interface { } // NewELBV2 constructs new ELBV2 implementation. -// TODO custom resolver for gamma endpoint func NewELBV2(cfg aws.Config) ELBV2 { client := elasticloadbalancingv2.NewFromConfig(cfg) return &elbv2Client{elbv2Client: client} @@ -72,24 +69,6 @@ type elbv2Client struct { elbv2Client *elasticloadbalancingv2.Client } -// TODO : Paginate this method once paginators are available -func (c *elbv2Client) DescribeListenerCertificatesAsList(ctx context.Context, input *elasticloadbalancingv2.DescribeListenerCertificatesInput) ([]types.Certificate, error) { - output, err := c.elbv2Client.DescribeListenerCertificates(ctx, input) - if err != nil { - return nil, err - } - return output.Certificates, nil -} - -// TODO : Paginate this method once paginators are available -func (c *elbv2Client) DescribeRulesAsList(ctx context.Context, input *elasticloadbalancingv2.DescribeRulesInput) ([]types.Rule, error) { - output, err := c.elbv2Client.DescribeRules(ctx, input) - if err != nil { - return nil, err - } - return output.Rules, nil -} - func (c *elbv2Client) AddListenerCertificatesWithContext(ctx context.Context, input *elasticloadbalancingv2.AddListenerCertificatesInput) (*elasticloadbalancingv2.AddListenerCertificatesOutput, error) { return c.elbv2Client.AddListenerCertificates(ctx, input) } @@ -258,3 +237,29 @@ func (c *elbv2Client) DescribeListenersAsList(ctx context.Context, input *elasti } return result, nil } + +func (c *elbv2Client) DescribeListenerCertificatesAsList(ctx context.Context, input *elasticloadbalancingv2.DescribeListenerCertificatesInput) ([]types.Certificate, error) { + var result []types.Certificate + paginator := elasticloadbalancingv2.NewDescribeListenerCertificatesPaginator(c.elbv2Client, input) + for paginator.HasMorePages() { + output, err := paginator.NextPage(ctx) + if err != nil { + return nil, err + } + result = append(result, output.Certificates...) + } + return result, nil +} + +func (c *elbv2Client) DescribeRulesAsList(ctx context.Context, input *elasticloadbalancingv2.DescribeRulesInput) ([]types.Rule, error) { + var result []types.Rule + paginator := elasticloadbalancingv2.NewDescribeRulesPaginator(c.elbv2Client, input) + for paginator.HasMorePages() { + output, err := paginator.NextPage(ctx) + if err != nil { + return nil, err + } + result = append(result, output.Rules...) + } + return result, nil +} diff --git a/pkg/aws/throttle/condition.go b/pkg/aws/throttle/condition.go index 8c1507bc5..08ef49a48 100644 --- a/pkg/aws/throttle/condition.go +++ b/pkg/aws/throttle/condition.go @@ -1,35 +1,27 @@ package throttle -// TODO : WIP -//import ( -// "github.com/aws/aws-sdk-go/aws/request" -// smithyhttp "github.com/aws/smithy-go/transport/http" -// awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware/" -// "regexp" -//) -// -//type Condition func(r *awshttp.) bool -// -//func matchService(serviceID string) Condition { -// return func(r *awshttp.) bool { -// return awsmiddleware.GetServiceID() == serviceID -// } -//} -// -//func matchServiceOperation(serviceID string, operation string) Condition { -// return func(r *smithyhttp.Request) bool { -// if r.Operation == nil { -// return false -// } -// return r.ClientInfo.ServiceID == serviceID && r.Operation.Name == operation -// } -//} -// -//func matchServiceOperationPattern(serviceID string, operationPtn *regexp.Regexp) Condition { -// return func(r *smithyhttp.Request) bool { -// if r.Operation == nil { -// return false -// } -// return r.ClientInfo.ServiceID == serviceID && operationPtn.Match([]byte(r.Operation.Name)) -// } -//} +import ( + "context" + awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware" + "regexp" +) + +type Condition func(ctx context.Context) bool + +func matchService(serviceID string) Condition { + return func(ctx context.Context) bool { + return awsmiddleware.GetServiceID(ctx) == serviceID + } +} + +func matchServiceOperation(serviceID string, operation string) Condition { + return func(ctx context.Context) bool { + return awsmiddleware.GetServiceID(ctx) == serviceID && awsmiddleware.GetOperationName(ctx) == operation + } +} + +func matchServiceOperationPattern(serviceID string, operationPtn *regexp.Regexp) Condition { + return func(ctx context.Context) bool { + return awsmiddleware.GetServiceID(ctx) == serviceID && operationPtn.Match([]byte(awsmiddleware.GetOperationName(ctx))) + } +} diff --git a/pkg/aws/throttle/condition_test.go b/pkg/aws/throttle/condition_test.go index 5493a2159..eb607b51c 100644 --- a/pkg/aws/throttle/condition_test.go +++ b/pkg/aws/throttle/condition_test.go @@ -1,206 +1,44 @@ package throttle -// TODO : WIP -//import ( -// "github.com/stretchr/testify/assert" -// "regexp" -// "testing" -//) -// -//func Test_matchService(t *testing.T) { -// type args struct { -// serviceID string -// } -// tests := []struct { -// name string -// args args -// req *request.Request -// want bool -// }{ -// { -// name: "service matches", -// args: args{ -// serviceID: "App Mesh", -// }, -// req: &request.Request{ -// ClientInfo: metadata.ClientInfo{ -// ServiceID: "App Mesh", -// }, -// }, -// want: true, -// }, -// { -// name: "service mismatches", -// args: args{ -// serviceID: "App Mesh", -// }, -// req: &request.Request{ -// ClientInfo: metadata.ClientInfo{ -// ServiceID: "S3", -// }, -// }, -// want: false, -// }, -// } -// for _, tt := range tests { -// t.Run(tt.name, func(t *testing.T) { -// predict := matchService(tt.args.serviceID) -// got := predict(tt.req) -// assert.Equal(t, tt.want, got) -// }) -// } -//} -// -//func Test_matchServiceOperation(t *testing.T) { -// type args struct { -// serviceID string -// operation string -// } -// tests := []struct { -// name string -// args args -// req *request.Request -// want bool -// }{ -// { -// name: "operation matches", -// args: args{ -// serviceID: "App Mesh", -// operation: "CreateMesh", -// }, -// req: &request.Request{ -// ClientInfo: metadata.ClientInfo{ -// ServiceID: "App Mesh", -// }, -// Operation: &request.Operation{ -// Name: "CreateMesh", -// }, -// }, -// want: true, -// }, -// { -// name: "operation mismatches", -// args: args{ -// serviceID: "App Mesh", -// operation: "CreateMesh", -// }, -// req: &request.Request{ -// ClientInfo: metadata.ClientInfo{ -// ServiceID: "App Mesh", -// }, -// Operation: &request.Operation{ -// Name: "DescribeMesh", -// }, -// }, -// want: false, -// }, -// } -// for _, tt := range tests { -// t.Run(tt.name, func(t *testing.T) { -// predict := matchServiceOperation(tt.args.serviceID, tt.args.operation) -// got := predict(tt.req) -// assert.Equal(t, tt.want, got) -// }) -// } -//} -// -//func Test_matchServiceOperationPattern(t *testing.T) { -// type args struct { -// serviceID string -// operationPtn *regexp.Regexp -// } -// tests := []struct { -// name string -// args args -// req *request.Request -// want bool -// }{ -// { -// name: "operationPtn matches - case 1", -// args: args{ -// serviceID: "App Mesh", -// operationPtn: regexp.MustCompile("Create"), -// }, -// req: &request.Request{ -// ClientInfo: metadata.ClientInfo{ -// ServiceID: "App Mesh", -// }, -// Operation: &request.Operation{ -// Name: "CreateMesh", -// }, -// }, -// want: true, -// }, -// { -// name: "operationPtn matches - case 2", -// args: args{ -// serviceID: "App Mesh", -// operationPtn: regexp.MustCompile("Create.*"), -// }, -// req: &request.Request{ -// ClientInfo: metadata.ClientInfo{ -// ServiceID: "App Mesh", -// }, -// Operation: &request.Operation{ -// Name: "CreateMesh", -// }, -// }, -// want: true, -// }, -// { -// name: "operationPtn matches - case 3", -// args: args{ -// serviceID: "App Mesh", -// operationPtn: regexp.MustCompile("^Create"), -// }, -// req: &request.Request{ -// ClientInfo: metadata.ClientInfo{ -// ServiceID: "App Mesh", -// }, -// Operation: &request.Operation{ -// Name: "CreateMesh", -// }, -// }, -// want: true, -// }, -// { -// name: "operationPtn matches - case 4", -// args: args{ -// serviceID: "App Mesh", -// operationPtn: regexp.MustCompile("Mesh"), -// }, -// req: &request.Request{ -// ClientInfo: metadata.ClientInfo{ -// ServiceID: "App Mesh", -// }, -// Operation: &request.Operation{ -// Name: "CreateMesh", -// }, -// }, -// want: true, -// }, -// { -// name: "operationPtn mismatches", -// args: args{ -// serviceID: "App Mesh", -// operationPtn: regexp.MustCompile("Describe"), -// }, -// req: &request.Request{ -// ClientInfo: metadata.ClientInfo{ -// ServiceID: "App Mesh", -// }, -// Operation: &request.Operation{ -// Name: "CreateMesh", -// }, -// }, -// want: false, -// }, -// } -// for _, tt := range tests { -// t.Run(tt.name, func(t *testing.T) { -// predict := matchServiceOperationPattern(tt.args.serviceID, tt.args.operationPtn) -// got := predict(tt.req) -// assert.Equal(t, tt.want, got) -// }) -// } -//} +import ( + "context" + awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware" + "github.com/stretchr/testify/assert" + "testing" +) + +func Test_matchService(t *testing.T) { + type args struct { + serviceID string + } + tests := []struct { + name string + args args + ctx context.Context + want bool + }{ + { + name: "service matches", + args: args{ + serviceID: "App Mesh", + }, + ctx: awsmiddleware.SetServiceID(context.TODO(), "App Mesh"), + want: true, + }, + { + name: "service mismatches", + args: args{ + serviceID: "App Mesh", + }, + ctx: awsmiddleware.SetServiceID(context.TODO(), "Some Service"), + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + predict := matchService(tt.args.serviceID) + got := predict(tt.ctx) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/aws/throttle/config_test.go b/pkg/aws/throttle/config_test.go index fda3d7e37..c27059945 100644 --- a/pkg/aws/throttle/config_test.go +++ b/pkg/aws/throttle/config_test.go @@ -1,258 +1,259 @@ package throttle -// TODO : WIP -//import ( -// "github.com/aws/aws-sdk-go-v2/service/appmesh" -// "github.com/aws/aws-sdk-go-v2/service/servicediscovery" -// "github.com/stretchr/testify/assert" -// "regexp" -// "testing" -//) -// -//func TestServiceOperationsThrottleConfig_String(t *testing.T) { -// type fields struct { -// value map[string][]throttleConfig -// } -// tests := []struct { -// name string -// fields fields -// want string -// }{ -// { -// name: "non-empty value", -// fields: fields{ -// value: map[string][]throttleConfig{ -// appmesh.ServiceID: { -// { -// operationPtn: regexp.MustCompile("^Describe"), -// r: 4.2, -// burst: 5, -// }, -// { -// operationPtn: regexp.MustCompile("CreateMesh"), -// r: 4.2, -// burst: 5, -// }, -// }, -// servicediscovery.ServiceID: { -// { -// operationPtn: regexp.MustCompile("^Describe"), -// r: 4.2, -// burst: 5, -// }, -// }, -// }, -// }, -// want: "App Mesh:^Describe=4.2:5,App Mesh:CreateMesh=4.2:5,ServiceDiscovery:^Describe=4.2:5", -// }, -// { -// name: "nil value", -// fields: fields{ -// value: nil, -// }, -// want: "", -// }, -// { -// name: "empty value", -// fields: fields{ -// value: nil, -// }, -// want: "", -// }, -// } -// for _, tt := range tests { -// t.Run(tt.name, func(t *testing.T) { -// c := &ServiceOperationsThrottleConfig{ -// value: tt.fields.value, -// } -// got := c.String() -// assert.Equal(t, tt.want, got) -// }) -// } -//} -// -//func TestServiceOperationsThrottleConfig_Set(t *testing.T) { -// type fields struct { -// value map[string][]throttleConfig -// } -// type args struct { -// val string -// } -// tests := []struct { -// name string -// fields fields -// args args -// want ServiceOperationsThrottleConfig -// wantErr error -// }{ -// { -// name: "when default value is nil", -// fields: fields{ -// value: nil, -// }, -// args: args{ -// val: "App Mesh:^Describe=4.2:5,App Mesh:CreateMesh=4.2:6,ServiceDiscovery:^Describe=4.2:7", -// }, -// want: ServiceOperationsThrottleConfig{ -// value: map[string][]throttleConfig{ -// appmesh.ServiceID: { -// { -// operationPtn: regexp.MustCompile("^Describe"), -// r: 4.2, -// burst: 5, -// }, -// { -// operationPtn: regexp.MustCompile("CreateMesh"), -// r: 4.2, -// burst: 6, -// }, -// }, -// servicediscovery.ServiceID: { -// { -// operationPtn: regexp.MustCompile("^Describe"), -// r: 4.2, -// burst: 7, -// }, -// }, -// }, -// }, -// }, -// { -// name: "when default value contains non-empty defaults", -// fields: fields{ -// value: map[string][]throttleConfig{ -// elbv2.ServiceID: { -// { -// operationPtn: regexp.MustCompile("^Create"), -// r: 4.2, -// burst: 4, -// }, -// }, -// }, -// }, -// args: args{ -// val: "App Mesh:^Describe=4.2:5,App Mesh:CreateMesh=4.2:6,ServiceDiscovery:^Describe=4.2:7", -// }, -// want: ServiceOperationsThrottleConfig{ -// value: map[string][]throttleConfig{ -// elbv2.ServiceID: { -// { -// operationPtn: regexp.MustCompile("^Create"), -// r: 4.2, -// burst: 4, -// }, -// }, -// appmesh.ServiceID: { -// { -// operationPtn: regexp.MustCompile("^Describe"), -// r: 4.2, -// burst: 5, -// }, -// { -// operationPtn: regexp.MustCompile("CreateMesh"), -// r: 4.2, -// burst: 6, -// }, -// }, -// servicediscovery.ServiceID: { -// { -// operationPtn: regexp.MustCompile("^Describe"), -// r: 4.2, -// burst: 7, -// }, -// }, -// }, -// }, -// }, -// { -// name: "when val is empty", -// fields: fields{ -// value: map[string][]throttleConfig{}, -// }, -// args: args{ -// val: "", -// }, -// wantErr: errors.Errorf(" must be formatted as serviceID:operationRegex=rate:burst"), -// }, -// { -// name: "when val is not valid format - case 1", -// fields: fields{ -// value: map[string][]throttleConfig{}, -// }, -// args: args{ -// val: "a=b=c", -// }, -// wantErr: errors.Errorf("a=b=c must be formatted as serviceID:operationRegex=rate:burst"), -// }, -// { -// name: "when val is not valid format - case 2", -// fields: fields{ -// value: map[string][]throttleConfig{}, -// }, -// args: args{ -// val: "a:b:c=4.2:5", -// }, -// wantErr: errors.Errorf("a:b:c must be formatted as serviceID:operationRegex"), -// }, -// { -// name: "when val is not valid format - case 3", -// fields: fields{ -// value: map[string][]throttleConfig{}, -// }, -// args: args{ -// val: "a:b=4.2:5:6", -// }, -// wantErr: errors.Errorf("4.2:5:6 must be formatted as rate:burst"), -// }, -// { -// name: "when operationPtn is not valid regex", -// fields: fields{ -// value: map[string][]throttleConfig{}, -// }, -// args: args{ -// val: "a:^[Describe=4.2:5", -// }, -// wantErr: errors.Errorf("^[Describe must be valid regex expression for operation"), -// }, -// { -// name: "when rate is not valid float number", -// fields: fields{ -// value: map[string][]throttleConfig{}, -// }, -// args: args{ -// val: "a:^Describe=4.x:5", -// }, -// wantErr: errors.Errorf("4.x must be valid float number as rate for operations per second"), -// }, -// { -// name: "when burst is not valid integer", -// fields: fields{ -// value: map[string][]throttleConfig{}, -// }, -// args: args{ -// val: "a:^Describe=4.2:5x", -// }, -// wantErr: errors.Errorf("5x must be valid integer as burst for operations"), -// }, -// } -// for _, tt := range tests { -// t.Run(tt.name, func(t *testing.T) { -// c := &ServiceOperationsThrottleConfig{ -// value: tt.fields.value, -// } -// err := c.Set(tt.args.val) -// if tt.wantErr != nil { -// assert.EqualError(t, err, tt.wantErr.Error()) -// } else { -// assert.NoError(t, err) -// assert.Equal(t, tt.want, *c) -// } -// }) -// } -//} -// -//func TestServiceOperationsThrottleConfig_Type(t *testing.T) { -// c := &ServiceOperationsThrottleConfig{} -// got := c.Type() -// assert.Equal(t, "serviceOperationsThrottleConfig", got) -//} +import ( + "github.com/aws/aws-sdk-go-v2/service/appmesh" + "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + "github.com/aws/aws-sdk-go-v2/service/servicediscovery" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "regexp" + "testing" +) + +func TestServiceOperationsThrottleConfig_String(t *testing.T) { + type fields struct { + value map[string][]throttleConfig + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "non-empty value", + fields: fields{ + value: map[string][]throttleConfig{ + appmesh.ServiceID: { + { + operationPtn: regexp.MustCompile("^Describe"), + r: 4.2, + burst: 5, + }, + { + operationPtn: regexp.MustCompile("CreateMesh"), + r: 4.2, + burst: 5, + }, + }, + servicediscovery.ServiceID: { + { + operationPtn: regexp.MustCompile("^Describe"), + r: 4.2, + burst: 5, + }, + }, + }, + }, + want: "App Mesh:^Describe=4.2:5,App Mesh:CreateMesh=4.2:5,ServiceDiscovery:^Describe=4.2:5", + }, + { + name: "nil value", + fields: fields{ + value: nil, + }, + want: "", + }, + { + name: "empty value", + fields: fields{ + value: nil, + }, + want: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &ServiceOperationsThrottleConfig{ + value: tt.fields.value, + } + got := c.String() + assert.Equal(t, tt.want, got) + }) + } +} + +func TestServiceOperationsThrottleConfig_Set(t *testing.T) { + type fields struct { + value map[string][]throttleConfig + } + type args struct { + val string + } + tests := []struct { + name string + fields fields + args args + want ServiceOperationsThrottleConfig + wantErr error + }{ + { + name: "when default value is nil", + fields: fields{ + value: nil, + }, + args: args{ + val: "App Mesh:^Describe=4.2:5,App Mesh:CreateMesh=4.2:6,ServiceDiscovery:^Describe=4.2:7", + }, + want: ServiceOperationsThrottleConfig{ + value: map[string][]throttleConfig{ + appmesh.ServiceID: { + { + operationPtn: regexp.MustCompile("^Describe"), + r: 4.2, + burst: 5, + }, + { + operationPtn: regexp.MustCompile("CreateMesh"), + r: 4.2, + burst: 6, + }, + }, + servicediscovery.ServiceID: { + { + operationPtn: regexp.MustCompile("^Describe"), + r: 4.2, + burst: 7, + }, + }, + }, + }, + }, + { + name: "when default value contains non-empty defaults", + fields: fields{ + value: map[string][]throttleConfig{ + elasticloadbalancingv2.ServiceID: { + { + operationPtn: regexp.MustCompile("^Create"), + r: 4.2, + burst: 4, + }, + }, + }, + }, + args: args{ + val: "App Mesh:^Describe=4.2:5,App Mesh:CreateMesh=4.2:6,ServiceDiscovery:^Describe=4.2:7", + }, + want: ServiceOperationsThrottleConfig{ + value: map[string][]throttleConfig{ + elasticloadbalancingv2.ServiceID: { + { + operationPtn: regexp.MustCompile("^Create"), + r: 4.2, + burst: 4, + }, + }, + appmesh.ServiceID: { + { + operationPtn: regexp.MustCompile("^Describe"), + r: 4.2, + burst: 5, + }, + { + operationPtn: regexp.MustCompile("CreateMesh"), + r: 4.2, + burst: 6, + }, + }, + servicediscovery.ServiceID: { + { + operationPtn: regexp.MustCompile("^Describe"), + r: 4.2, + burst: 7, + }, + }, + }, + }, + }, + { + name: "when val is empty", + fields: fields{ + value: map[string][]throttleConfig{}, + }, + args: args{ + val: "", + }, + wantErr: errors.Errorf(" must be formatted as serviceID:operationRegex=rate:burst"), + }, + { + name: "when val is not valid format - case 1", + fields: fields{ + value: map[string][]throttleConfig{}, + }, + args: args{ + val: "a=b=c", + }, + wantErr: errors.Errorf("a=b=c must be formatted as serviceID:operationRegex=rate:burst"), + }, + { + name: "when val is not valid format - case 2", + fields: fields{ + value: map[string][]throttleConfig{}, + }, + args: args{ + val: "a:b:c=4.2:5", + }, + wantErr: errors.Errorf("a:b:c must be formatted as serviceID:operationRegex"), + }, + { + name: "when val is not valid format - case 3", + fields: fields{ + value: map[string][]throttleConfig{}, + }, + args: args{ + val: "a:b=4.2:5:6", + }, + wantErr: errors.Errorf("4.2:5:6 must be formatted as rate:burst"), + }, + { + name: "when operationPtn is not valid regex", + fields: fields{ + value: map[string][]throttleConfig{}, + }, + args: args{ + val: "a:^[Describe=4.2:5", + }, + wantErr: errors.Errorf("^[Describe must be valid regex expression for operation"), + }, + { + name: "when rate is not valid float number", + fields: fields{ + value: map[string][]throttleConfig{}, + }, + args: args{ + val: "a:^Describe=4.x:5", + }, + wantErr: errors.Errorf("4.x must be valid float number as rate for operations per second"), + }, + { + name: "when burst is not valid integer", + fields: fields{ + value: map[string][]throttleConfig{}, + }, + args: args{ + val: "a:^Describe=4.2:5x", + }, + wantErr: errors.Errorf("5x must be valid integer as burst for operations"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &ServiceOperationsThrottleConfig{ + value: tt.fields.value, + } + err := c.Set(tt.args.val) + if tt.wantErr != nil { + assert.EqualError(t, err, tt.wantErr.Error()) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, *c) + } + }) + } +} + +func TestServiceOperationsThrottleConfig_Type(t *testing.T) { + c := &ServiceOperationsThrottleConfig{} + got := c.Type() + assert.Equal(t, "serviceOperationsThrottleConfig", got) +} diff --git a/pkg/aws/throttle/throttler.go b/pkg/aws/throttle/throttler.go index 4b5a7b7eb..730bf8d6c 100644 --- a/pkg/aws/throttle/throttler.go +++ b/pkg/aws/throttle/throttler.go @@ -1,71 +1,81 @@ package throttle -// TODO : WIP -//import ( -// "github.com/aws/aws-sdk-go/aws/request" -// "golang.org/x/time/rate" -// "regexp" -//) -// -//const sdkHandlerRequestThrottle = "requestThrottle" -// -//type conditionLimiter struct { -// condition Condition -// limiter *rate.Limiter -//} -// -//type throttler struct { -// conditionLimiters []conditionLimiter -//} -// -//// NewThrottler constructs new request throttler instance. -//func NewThrottler(config *ServiceOperationsThrottleConfig) *throttler { -// throttler := &throttler{} -// for serviceID, operationsThrottleConfigs := range config.value { -// for _, operationsThrottleConfig := range operationsThrottleConfigs { -// throttler = throttler.WithOperationPatternThrottle( -// serviceID, -// operationsThrottleConfig.operationPtn, -// operationsThrottleConfig.r, -// operationsThrottleConfig.burst) -// } -// } -// return throttler -//} -// -//func (t *throttler) WithConditionThrottle(condition Condition, r rate.Limit, burst int) *throttler { -// limiter := rate.NewLimiter(r, burst) -// t.conditionLimiters = append(t.conditionLimiters, conditionLimiter{ -// condition: condition, -// limiter: limiter, -// }) -// return t -//} -// -//func (t *throttler) WithServiceThrottle(serviceID string, r rate.Limit, burst int) *throttler { -// return t.WithConditionThrottle(matchService(serviceID), r, burst) -//} -// -//func (t *throttler) WithOperationThrottle(serviceID string, operation string, r rate.Limit, burst int) *throttler { -// return t.WithConditionThrottle(matchServiceOperation(serviceID, operation), r, burst) -//} -// -//func (t *throttler) WithOperationPatternThrottle(serviceID string, operationPtn *regexp.Regexp, r rate.Limit, burst int) *throttler { -// return t.WithConditionThrottle(matchServiceOperationPattern(serviceID, operationPtn), r, burst) -//} -// -//func (t *throttler) InjectHandlers(handlers *request.Handlers) { -// handlers.Sign.PushFrontNamed(request.NamedHandler{ -// Name: sdkHandlerRequestThrottle, -// Fn: t.beforeSign, -// }) -//} -// -//// beforeSign is added to the Sign chain; called before each request -//func (t *throttler) beforeSign(r *request.Request) { -// for _, conditionLimiter := range t.conditionLimiters { -// if conditionLimiter.condition(r) { -// conditionLimiter.limiter.Wait(r.Context()) -// } -// } -//} +import ( + "context" + smithymiddleware "github.com/aws/smithy-go/middleware" + "golang.org/x/time/rate" + "regexp" +) + +const sdkHandlerRequestThrottle = "requestThrottle" + +type conditionLimiter struct { + condition Condition + limiter *rate.Limiter +} + +type throttler struct { + conditionLimiters []conditionLimiter +} + +// NewThrottler constructs new request throttler instance. +func NewThrottler(config *ServiceOperationsThrottleConfig) *throttler { + throttler := &throttler{} + for serviceID, operationsThrottleConfigs := range config.value { + for _, operationsThrottleConfig := range operationsThrottleConfigs { + throttler = throttler.WithOperationPatternThrottle( + serviceID, + operationsThrottleConfig.operationPtn, + operationsThrottleConfig.r, + operationsThrottleConfig.burst) + } + } + return throttler +} + +func (t *throttler) WithConditionThrottle(condition Condition, r rate.Limit, burst int) *throttler { + limiter := rate.NewLimiter(r, burst) + t.conditionLimiters = append(t.conditionLimiters, conditionLimiter{ + condition: condition, + limiter: limiter, + }) + return t +} + +func (t *throttler) WithServiceThrottle(serviceID string, r rate.Limit, burst int) *throttler { + return t.WithConditionThrottle(matchService(serviceID), r, burst) +} + +func (t *throttler) WithOperationThrottle(serviceID string, operation string, r rate.Limit, burst int) *throttler { + return t.WithConditionThrottle(matchServiceOperation(serviceID, operation), r, burst) +} + +func (t *throttler) WithOperationPatternThrottle(serviceID string, operationPtn *regexp.Regexp, r rate.Limit, burst int) *throttler { + return t.WithConditionThrottle(matchServiceOperationPattern(serviceID, operationPtn), r, burst) +} + +/* +WithSDKRequestThrottleMiddleware is a middleware that applies client side rate limiting to the clients. This is added in finalize step of middleware stack +and is called before each request in middleware chain +*/ +func WithSDKRequestThrottleMiddleware(throttler *throttler) func(stack *smithymiddleware.Stack) error { + return func(stack *smithymiddleware.Stack) error { + return stack.Finalize.Add(smithymiddleware.FinalizeMiddlewareFunc(sdkHandlerRequestThrottle, func( + ctx context.Context, input smithymiddleware.FinalizeInput, next smithymiddleware.FinalizeHandler, + ) ( + output smithymiddleware.FinalizeOutput, metadata smithymiddleware.Metadata, err error, + ) { + throttler.beforeSign(ctx) + return next.HandleFinalize(ctx, input) + }), smithymiddleware.Before) + } +} + +// beforeSign is added to the Finalize step of middleware stack; called before each request +func (t *throttler) beforeSign(ctx context.Context) { + for _, conditionLimiter := range t.conditionLimiters { + if conditionLimiter.condition(ctx) { + conditionLimiter.limiter.Wait(ctx) + } + } +} diff --git a/pkg/aws/throttle/throttler_test.go b/pkg/aws/throttle/throttler_test.go index e52921b66..8c91107cf 100644 --- a/pkg/aws/throttle/throttler_test.go +++ b/pkg/aws/throttle/throttler_test.go @@ -1,285 +1,234 @@ package throttle -// TODO : WIP -//import ( -// "context" -// "github.com/aws/aws-sdk-go-v2/service/appmesh" -// "github.com/aws/aws-sdk-go-v2/service/servicediscovery" -// "github.com/stretchr/testify/assert" -// "golang.org/x/time/rate" -// "net/http" -// "regexp" -// "sync" -// "sync/atomic" -// "testing" -// "time" -//) -// -//func Test_NewThrottler(t *testing.T) { -// config := ServiceOperationsThrottleConfig{ -// value: map[string][]throttleConfig{ -// appmesh.ServiceID: { -// { -// operationPtn: regexp.MustCompile("^Describe"), -// r: 4.2, -// burst: 5, -// }, -// { -// operationPtn: regexp.MustCompile("CreateMesh"), -// r: 3.8, -// burst: 4, -// }, -// }, -// servicediscovery.ServiceID: { -// { -// operationPtn: regexp.MustCompile("^Create"), -// r: 1.2, -// burst: 2, -// }, -// }, -// }, -// } -// -// throttler := NewThrottler(&config) -// assert.Equal(t, 3, len(throttler.conditionLimiters)) -//} -// -//func Test_throttler_WithConditionThrottle(t *testing.T) { -// throttler := &throttler{} -// throttler.WithConditionThrottle(matchService(appmesh.ServiceID), 5.0, 10) -// -// assert.Equal(t, 1, len(throttler.conditionLimiters)) -// -// cl := throttler.conditionLimiters[0] -// assert.True(t, cl.condition(&request.Request{ClientInfo: metadata.ClientInfo{ServiceID: appmesh.ServiceID}})) -// assert.Equal(t, rate.NewLimiter(5.0, 10), cl.limiter) -//} -// -//func Test_throttler_WithServiceThrottle(t *testing.T) { -// throttler := &throttler{} -// throttler.WithServiceThrottle(appmesh.ServiceID, 5.0, 10) -// -// assert.Equal(t, 1, len(throttler.conditionLimiters)) -// -// cl := throttler.conditionLimiters[0] -// assert.True(t, cl.condition(&request.Request{ClientInfo: metadata.ClientInfo{ServiceID: appmesh.ServiceID}})) -// assert.Equal(t, rate.NewLimiter(5.0, 10), cl.limiter) -//} -// -//func Test_throttler_WithOperationThrottle(t *testing.T) { -// throttler := &throttler{} -// throttler.WithOperationThrottle(appmesh.ServiceID, "CreateMesh", 5.0, 10) -// -// assert.Equal(t, 1, len(throttler.conditionLimiters)) -// -// cl := throttler.conditionLimiters[0] -// assert.True(t, cl.condition(&request.Request{ -// ClientInfo: metadata.ClientInfo{ServiceID: appmesh.ServiceID}, -// Operation: &request.Operation{Name: "CreateMesh"}, -// })) -// assert.Equal(t, rate.NewLimiter(5.0, 10), cl.limiter) -//} -// -//func Test_throttler_WithOperationPatternThrottle(t *testing.T) { -// throttler := &throttler{} -// throttler.WithOperationPatternThrottle(appmesh.ServiceID, regexp.MustCompile("^Create"), 5.0, 10) -// -// assert.Equal(t, 1, len(throttler.conditionLimiters)) -// -// cl := throttler.conditionLimiters[0] -// assert.True(t, cl.condition(&request.Request{ -// ClientInfo: metadata.ClientInfo{ServiceID: appmesh.ServiceID}, -// Operation: &request.Operation{Name: "CreateMesh"}, -// })) -// assert.Equal(t, rate.NewLimiter(5.0, 10), cl.limiter) -//} -// -//func Test_throttler_InjectHandlers(t *testing.T) { -// throttler := &throttler{} -// handlers := request.Handlers{} -// throttler.InjectHandlers(&handlers) -// assert.Equal(t, 1, handlers.Sign.Len()) -//} -// -//// Test beforeSign to check whether throttle applies correctly. -//// Note: the validCallsCount checks whether the observed calls falls into [ideal-1, ideal+1] -//// it shouldn't be too precisely to avoid false alarms caused by CPU load when running tests. -//// structure your limits and testQPS, so that the expect QPS with/without throttle differs dramatically. (e.g. 10x) -//func Test_throttler_beforeSign(t *testing.T) { -// type fields struct { -// conditionLimiters []conditionLimiter -// } -// type args struct { -// r *request.Request -// } -// tests := []struct { -// name string -// fields fields -// args args -// testDuration time.Duration -// testQPS int64 -// validCallsCount func(elapsedDuration time.Duration, observedCallsCount int64) -// }{ -// { -// name: "[single matching condition] throttle should applies", -// fields: fields{ -// conditionLimiters: []conditionLimiter{ -// { -// condition: func(r *request.Request) bool { -// return true -// }, -// limiter: rate.NewLimiter(10, 5), -// }, -// }, -// }, -// args: args{ -// r: &request.Request{ -// HTTPRequest: &http.Request{}, -// }, -// }, -// testQPS: 100, -// validCallsCount: func(elapsedDuration time.Duration, count int64) { -// ideal := 5 + 10*elapsedDuration.Seconds() -// // We should never get more requests than allowed. -// if want := int64(ideal * 1.1); count > want { -// t.Errorf("count = %d, want %d (ideal %f", count, want, ideal) -// } -// // We should get very close to the number of requests allowed. -// if want := int64(ideal * 0.9); count < want { -// t.Errorf("count = %d, want %d (ideal %f", count, want, ideal) -// } -// }, -// }, -// { -// name: "[single non-matching condition] throttle shouldn't applies", -// fields: fields{ -// conditionLimiters: []conditionLimiter{ -// { -// condition: func(r *request.Request) bool { -// return false -// }, -// limiter: rate.NewLimiter(10, 5), -// }, -// }, -// }, -// args: args{ -// r: &request.Request{ -// HTTPRequest: &http.Request{}, -// }, -// }, -// testQPS: 100, -// validCallsCount: func(elapsedDuration time.Duration, count int64) { -// ideal := 100 * elapsedDuration.Seconds() -// // We should never get more requests than allowed. -// if want := int64(ideal * 1.1); count > want { -// t.Errorf("count = %d, want %d (ideal %f", count, want, ideal) -// } -// // We should get very close to the number of requests allowed. -// if want := int64(ideal * 0.9); count < want { -// t.Errorf("count = %d, want %d (ideal %f", count, want, ideal) -// } -// }, -// }, -// { -// name: "[two condition, one matching and another non-matching] matching throttle should applies", -// fields: fields{ -// conditionLimiters: []conditionLimiter{ -// { -// condition: func(r *request.Request) bool { -// return true -// }, -// limiter: rate.NewLimiter(10, 5), -// }, -// { -// condition: func(r *request.Request) bool { -// return false -// }, -// limiter: rate.NewLimiter(1, 5), -// }, -// }, -// }, -// args: args{ -// r: &request.Request{ -// HTTPRequest: &http.Request{}, -// }, -// }, -// testQPS: 100, -// validCallsCount: func(elapsedDuration time.Duration, count int64) { -// ideal := 5 + 10*elapsedDuration.Seconds() -// // We should never get more requests than allowed. -// if want := int64(ideal * 1.1); count > want { -// t.Errorf("count = %d, want %d (ideal %f", count, want, ideal) -// } -// // We should get very close to the number of requests allowed. -// if want := int64(ideal * 0.9); count < want { -// t.Errorf("count = %d, want %d (ideal %f", count, want, ideal) -// } -// }, -// }, -// { -// name: "[two condition, both matching] most restrictive throttle should applies", -// fields: fields{ -// conditionLimiters: []conditionLimiter{ -// { -// condition: func(r *request.Request) bool { -// return true -// }, -// limiter: rate.NewLimiter(10, 5), -// }, -// { -// condition: func(r *request.Request) bool { -// return true -// }, -// limiter: rate.NewLimiter(1, 5), -// }, -// }, -// }, -// args: args{ -// r: &request.Request{ -// HTTPRequest: &http.Request{}, -// }, -// }, -// testQPS: 100, -// validCallsCount: func(elapsedDuration time.Duration, count int64) { -// ideal := 5 + 1*elapsedDuration.Seconds() -// // We should never get more requests than allowed. -// if want := int64(ideal * 1.1); count > want { -// t.Errorf("count = %d, want %d (ideal %f", count, want, ideal) -// } -// // We should get very close to the number of requests allowed. -// if want := int64(ideal * 0.9); count < want { -// t.Errorf("count = %d, want %d (ideal %f", count, want, ideal) -// } -// }, -// }, -// } -// for _, tt := range tests { -// t.Run(tt.name, func(t1 *testing.T) { -// throttler := &throttler{ -// conditionLimiters: tt.fields.conditionLimiters, -// } -// -// ctx, cancel := context.WithCancel(context.Background()) -// tt.args.r.SetContext(ctx) -// -// observedCount := int64(0) -// start := time.Now() -// end := start.Add(time.Second * 1) -// testQPSThrottle := time.Tick(time.Second / time.Duration(tt.testQPS)) -// var wg sync.WaitGroup -// for time.Now().Before(end) { -// wg.Add(1) -// go func() { -// throttler.beforeSign(tt.args.r) -// atomic.AddInt64(&observedCount, 1) -// wg.Done() -// }() -// <-testQPSThrottle -// } -// elapsed := time.Since(start) -// tt.validCallsCount(elapsed, atomic.LoadInt64(&observedCount)) -// cancel() -// wg.Wait() -// }) -// } -//} +import ( + "context" + awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware" + "github.com/aws/aws-sdk-go-v2/service/appmesh" + "github.com/aws/aws-sdk-go-v2/service/servicediscovery" + "github.com/stretchr/testify/assert" + "golang.org/x/time/rate" + "regexp" + "sync" + "sync/atomic" + "testing" + "time" +) + +func Test_NewThrottler(t *testing.T) { + config := ServiceOperationsThrottleConfig{ + value: map[string][]throttleConfig{ + appmesh.ServiceID: { + { + operationPtn: regexp.MustCompile("^Describe"), + r: 4.2, + burst: 5, + }, + { + operationPtn: regexp.MustCompile("CreateMesh"), + r: 3.8, + burst: 4, + }, + }, + servicediscovery.ServiceID: { + { + operationPtn: regexp.MustCompile("^Create"), + r: 1.2, + burst: 2, + }, + }, + }, + } + + throttler := NewThrottler(&config) + assert.Equal(t, 3, len(throttler.conditionLimiters)) +} + +func Test_throttler_WithConditionThrottle(t *testing.T) { + throttler := &throttler{} + throttler.WithConditionThrottle(matchService(appmesh.ServiceID), 5.0, 10) + + assert.Equal(t, 1, len(throttler.conditionLimiters)) + + cl := throttler.conditionLimiters[0] + ctx := awsmiddleware.SetServiceID(context.TODO(), appmesh.ServiceID) + assert.True(t, cl.condition(ctx)) + assert.Equal(t, rate.NewLimiter(5.0, 10), cl.limiter) +} + +func Test_throttler_WithServiceThrottle(t *testing.T) { + throttler := &throttler{} + throttler.WithServiceThrottle(appmesh.ServiceID, 5.0, 10) + + assert.Equal(t, 1, len(throttler.conditionLimiters)) + + cl := throttler.conditionLimiters[0] + ctx := awsmiddleware.SetServiceID(context.TODO(), appmesh.ServiceID) + assert.True(t, cl.condition(ctx)) + assert.Equal(t, rate.NewLimiter(5.0, 10), cl.limiter) +} + +// Test beforeSign to check whether throttle applies correctly. +// Note: the validCallsCount checks whether the observed calls falls into [ideal-1, ideal+1] +// it shouldn't be too precisely to avoid false alarms caused by CPU load when running tests. +// structure your limits and testQPS, so that the expect QPS with/without throttle differs dramatically. (e.g. 10x) +func Test_throttler_beforeSign(t *testing.T) { + type fields struct { + conditionLimiters []conditionLimiter + } + type args struct { + ctx context.Context + } + tests := []struct { + name string + fields fields + args args + testDuration time.Duration + testQPS int64 + validCallsCount func(elapsedDuration time.Duration, observedCallsCount int64) + }{ + { + name: "[single matching condition] throttle should applies", + fields: fields{ + conditionLimiters: []conditionLimiter{ + { + condition: func(ctx context.Context) bool { + return true + }, + limiter: rate.NewLimiter(10, 5), + }, + }, + }, + args: args{}, + testQPS: 100, + validCallsCount: func(elapsedDuration time.Duration, count int64) { + ideal := 5 + 10*elapsedDuration.Seconds() + // We should never get more requests than allowed. + if want := int64(ideal * 1.1); count > want { + t.Errorf("count = %d, want %d (ideal %f", count, want, ideal) + } + // We should get very close to the number of requests allowed. + if want := int64(ideal * 0.9); count < want { + t.Errorf("count = %d, want %d (ideal %f", count, want, ideal) + } + }, + }, + { + name: "[single non-matching condition] throttle shouldn't applies", + fields: fields{ + conditionLimiters: []conditionLimiter{ + { + condition: func(ctx context.Context) bool { + return false + }, + limiter: rate.NewLimiter(10, 5), + }, + }, + }, + args: args{}, + testQPS: 100, + validCallsCount: func(elapsedDuration time.Duration, count int64) { + ideal := 100 * elapsedDuration.Seconds() + // We should never get more requests than allowed. + if want := int64(ideal * 1.1); count > want { + t.Errorf("count = %d, want %d (ideal %f", count, want, ideal) + } + // We should get very close to the number of requests allowed. + if want := int64(ideal * 0.9); count < want { + t.Errorf("count = %d, want %d (ideal %f", count, want, ideal) + } + }, + }, + { + name: "[two condition, one matching and another non-matching] matching throttle should applies", + fields: fields{ + conditionLimiters: []conditionLimiter{ + { + condition: func(ctx context.Context) bool { + return true + }, + limiter: rate.NewLimiter(10, 5), + }, + { + condition: func(ctx context.Context) bool { + return false + }, + limiter: rate.NewLimiter(1, 5), + }, + }, + }, + args: args{}, + testQPS: 100, + validCallsCount: func(elapsedDuration time.Duration, count int64) { + ideal := 5 + 10*elapsedDuration.Seconds() + // We should never get more requests than allowed. + if want := int64(ideal * 1.1); count > want { + t.Errorf("count = %d, want %d (ideal %f", count, want, ideal) + } + // We should get very close to the number of requests allowed. + if want := int64(ideal * 0.9); count < want { + t.Errorf("count = %d, want %d (ideal %f", count, want, ideal) + } + }, + }, + { + name: "[two condition, both matching] most restrictive throttle should applies", + fields: fields{ + conditionLimiters: []conditionLimiter{ + { + condition: func(ctx context.Context) bool { + return true + }, + limiter: rate.NewLimiter(10, 5), + }, + { + condition: func(ctx context.Context) bool { + return true + }, + limiter: rate.NewLimiter(1, 5), + }, + }, + }, + args: args{}, + testQPS: 100, + validCallsCount: func(elapsedDuration time.Duration, count int64) { + ideal := 5 + 1*elapsedDuration.Seconds() + // We should never get more requests than allowed. + if want := int64(ideal * 1.1); count > want { + t.Errorf("count = %d, want %d (ideal %f", count, want, ideal) + } + // We should get very close to the number of requests allowed. + if want := int64(ideal * 0.9); count < want { + t.Errorf("count = %d, want %d (ideal %f", count, want, ideal) + } + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t1 *testing.T) { + throttler := &throttler{ + conditionLimiters: tt.fields.conditionLimiters, + } + + ctx, cancel := context.WithCancel(context.Background()) + + observedCount := int64(0) + start := time.Now() + end := start.Add(time.Second * 1) + testQPSThrottle := time.Tick(time.Second / time.Duration(tt.testQPS)) + var wg sync.WaitGroup + for time.Now().Before(end) { + wg.Add(1) + go func() { + throttler.beforeSign(ctx) + atomic.AddInt64(&observedCount, 1) + wg.Done() + }() + <-testQPSThrottle + } + elapsed := time.Since(start) + tt.validCallsCount(elapsed, atomic.LoadInt64(&observedCount)) + cancel() + wg.Wait() + }) + } +}