diff --git a/docs/README.md b/docs/README.md index aa50132..ce13b0d 100644 --- a/docs/README.md +++ b/docs/README.md @@ -13,6 +13,7 @@ - [Key-Value Pairs](libs/pairs.md) - [Readiness Checks](libs/readiness.md) - [Kubernetes Resource Management](libs/resource.md) +- [Retrying k8s Operations](libs/retry.md) - [Kubernetes Resource Status Updating](libs/status.md) - [Testing](libs/testing.md) - [Thread Management](libs/threads.md) diff --git a/docs/libs/retry.md b/docs/libs/retry.md new file mode 100644 index 0000000..14b26b5 --- /dev/null +++ b/docs/libs/retry.md @@ -0,0 +1,26 @@ +# Retrying k8s Operations + +The `pkg/retry` package contains a `Client` that wraps a `client.Client` while implementing the interface itself and retries any failed (= the returned error is not `nil`) operation. +Methods that don't return an error are simply forwarded to the internal client. + +In addition to the `client.Client` interface's methods, the `retry.Client` also has `CreateOrUpdate` and `CreateOrPatch` methods, which use the corresponding controller-runtime implementations internally. + +The default retry parameters are: +- retry every 100 milliseconds +- don't increase retry interval +- no maximum number of attempts +- timeout after 1 second + +The `retry.Client` struct has builder-style methods to configure the parameters: +```golang +retryingClient := retry.NewRetryingClient(myClient). + WithTimeout(10 * time.Second). // try for at max 10 seconds + WithInterval(500 * time.Millisecond). // try every 500 milliseconds, but ... + WithBackoffMultiplier(2.0) // ... double the interval after each retry +``` + +For convenience, the `clusters.Cluster` type can return a retrying client for its internal client: +```golang +// cluster is of type *clusters.Cluster +err := cluster.Retry().WithMaxAttempts(3).Get(...) +``` diff --git a/pkg/clusters/cluster.go b/pkg/clusters/cluster.go index 91a8d65..a75e976 100644 --- a/pkg/clusters/cluster.go +++ b/pkg/clusters/cluster.go @@ -12,6 +12,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/cluster" "github.com/openmcp-project/controller-utils/pkg/controller" + "github.com/openmcp-project/controller-utils/pkg/retry" ) type Cluster struct { @@ -238,6 +239,12 @@ func (c *Cluster) APIServerEndpoint() string { return c.restCfg.Host } +// Retry returns a retrying client for the cluster. +// Returns nil if the client has not been initialized. +func (c *Cluster) Retry() *retry.Client { + return retry.NewRetryingClient(c.Client()) +} + ///////////////// // Serializing // ///////////////// diff --git a/pkg/collections/maps/utils.go b/pkg/collections/maps/utils.go index 38f33e5..b801345 100644 --- a/pkg/collections/maps/utils.go +++ b/pkg/collections/maps/utils.go @@ -1,6 +1,11 @@ package maps -import "github.com/openmcp-project/controller-utils/pkg/collections/filters" +import ( + "k8s.io/utils/ptr" + + "github.com/openmcp-project/controller-utils/pkg/collections/filters" + "github.com/openmcp-project/controller-utils/pkg/pairs" +) // Filter filters a map by applying a filter function to each key-value pair. // Only the entries for which the filter function returns true are kept in the copy. @@ -50,3 +55,34 @@ func Intersect[K comparable, V any](source map[K]V, maps ...map[K]V) map[K]V { return res } + +// MapKeys returns a slice of all keys in the map. +// The order is unspecified. +// The keys are not deep-copied, so changes to them could affect the original map. +func MapKeys[K comparable, V any](m map[K]V) []K { + keys := make([]K, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys +} + +// MapValues returns a slice of all values in the map. +// The order is unspecified. +// The values are not deep-copied, so changes to them could affect the original map. +func MapValues[K comparable, V any](m map[K]V) []V { + values := make([]V, 0, len(m)) + for _, v := range m { + values = append(values, v) + } + return values +} + +// GetAny returns an arbitrary key-value pair from the map as a pointer to a pairs.Pair. +// If the map is empty, it returns nil. +func GetAny[K comparable, V any](m map[K]V) *pairs.Pair[K, V] { + for k, v := range m { + return ptr.To(pairs.New(k, v)) + } + return nil +} diff --git a/pkg/collections/maps/utils_test.go b/pkg/collections/maps/utils_test.go index e3f456b..949c25b 100644 --- a/pkg/collections/maps/utils_test.go +++ b/pkg/collections/maps/utils_test.go @@ -7,7 +7,7 @@ import ( "github.com/openmcp-project/controller-utils/pkg/collections/maps" ) -var _ = Describe("LinkedIterator Tests", func() { +var _ = Describe("Map Utils Tests", func() { Context("Merge", func() { @@ -60,4 +60,56 @@ var _ = Describe("LinkedIterator Tests", func() { }) + Context("MapKeys", func() { + + It("should return all keys in the map", func() { + m1 := map[string]string{"foo": "bar", "bar": "baz", "foobar": "foobaz"} + keys := maps.MapKeys(m1) + Expect(keys).To(ConsistOf("foo", "bar", "foobar")) + Expect(len(keys)).To(Equal(3)) + }) + + It("should return an empty slice for an empty or nil map", func() { + var nilMap map[string]string + Expect(maps.MapKeys(nilMap)).To(BeEmpty()) + Expect(maps.MapKeys(map[string]string{})).To(BeEmpty()) + }) + + }) + + Context("MapValues", func() { + + It("should return all values in the map", func() { + m1 := map[string]string{"foo": "bar", "bar": "baz", "foobar": "foobaz"} + values := maps.MapValues(m1) + Expect(values).To(ConsistOf("bar", "baz", "foobaz")) + Expect(len(values)).To(Equal(3)) + }) + + It("should return an empty slice for an empty or nil map", func() { + var nilMap map[string]string + Expect(maps.MapValues(nilMap)).To(BeEmpty()) + Expect(maps.MapValues(map[string]string{})).To(BeEmpty()) + }) + + }) + + Context("GetAny", func() { + + It("should return a key-value pair from the map", func() { + m1 := map[string]string{"foo": "bar", "bar": "baz", "foobar": "foobaz"} + pair := maps.GetAny(m1) + Expect(pair).ToNot(BeNil()) + Expect(pair.Key).To(BeElementOf("foo", "bar", "foobar")) + Expect(m1[pair.Key]).To(Equal(pair.Value)) + }) + + It("should return nil for an empty or nil map", func() { + var nilMap map[string]string + Expect(maps.GetAny(nilMap)).To(BeNil()) + Expect(maps.GetAny(map[string]string{})).To(BeNil()) + }) + + }) + }) diff --git a/pkg/collections/utils.go b/pkg/collections/utils.go new file mode 100644 index 0000000..f5b03f2 --- /dev/null +++ b/pkg/collections/utils.go @@ -0,0 +1,78 @@ +package collections + +// ProjectSlice takes a slice and a projection function and applies this function to each element of the slice. +// It returns a new slice containing the results of the projection. +// The original slice is not modified. +// If the projection function is nil, it returns nil. +func ProjectSlice[X any, Y any](src []X, project func(X) Y) []Y { + if project == nil { + return nil + } + res := make([]Y, len(src)) + for i, src := range src { + res[i] = project(src) + } + return res +} + +// ProjectMapToSlice takes a map and a projection function and applies this function to each key-value pair in the map. +// It returns a new slice containing the results of the projection. +// The original map is not modified. +// If the projection function is nil, it returns nil. +func ProjectMapToSlice[K comparable, V any, R any](src map[K]V, project func(K, V) R) []R { + if project == nil { + return nil + } + res := make([]R, 0, len(src)) + for k, v := range src { + res = append(res, project(k, v)) + } + return res +} + +// ProjectMapToMap takes a map and a projection function and applies this function to each key-value pair in the map. +// It returns a new map containing the results of the projection. +// The original map is not modified. +// Note that the resulting map may be smaller if the projection function does not guarantee unique keys. +// If the projection function is nil, it returns nil. +func ProjectMapToMap[K1 comparable, V1 any, K2 comparable, V2 any](src map[K1]V1, project func(K1, V1) (K2, V2)) map[K2]V2 { + if project == nil { + return nil + } + res := make(map[K2]V2, len(src)) + for k, v := range src { + newK, newV := project(k, v) + res[newK] = newV + } + return res +} + +// AggregateSlice takes a slice, an aggregation function and an initial value. +// It applies the aggregation function to each element of the slice, also passing in the current result. +// For the first element, it uses the initial value as the current result. +// Returns initial if the aggregation function is nil. +func AggregateSlice[X any, Y any](src []X, agg func(X, Y) Y, initial Y) Y { + if agg == nil { + return initial + } + res := initial + for _, x := range src { + res = agg(x, res) + } + return res +} + +// AggregateMap takes a map, an aggregation function and an initial value. +// It applies the aggregation function to each key-value pair in the map, also passing in the current result. +// For the first key-value pair, it uses the initial value as the current result. +// Returns initial if the aggregation function is nil. +func AggregateMap[K comparable, V any, R any](src map[K]V, agg func(K, V, R) R, initial R) R { + if agg == nil { + return initial + } + res := initial + for k, v := range src { + res = agg(k, v, res) + } + return res +} diff --git a/pkg/collections/utils_test.go b/pkg/collections/utils_test.go new file mode 100644 index 0000000..1948fe9 --- /dev/null +++ b/pkg/collections/utils_test.go @@ -0,0 +1,93 @@ +package collections_test + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/openmcp-project/controller-utils/pkg/collections" +) + +var _ = Describe("Utils Tests", func() { + + Context("ProjectSlice", func() { + + projectFunc := func(i int) int { + return i * 2 + } + + It("should use the projection function on each element of the slice", func() { + src := []int{1, 2, 3, 4} + projected := collections.ProjectSlice(src, projectFunc) + Expect(projected).To(Equal([]int{2, 4, 6, 8})) + Expect(src).To(Equal([]int{1, 2, 3, 4}), "original slice should not be modified") + }) + + It("should return an empty slice for an empty or nil input slice", func() { + Expect(collections.ProjectSlice(nil, projectFunc)).To(BeEmpty()) + Expect(collections.ProjectSlice([]int{}, projectFunc)).To(BeEmpty()) + }) + + It("should return nil for a nil projection function", func() { + src := []int{1, 2, 3, 4} + projected := collections.ProjectSlice[int, int](src, nil) + Expect(projected).To(BeNil()) + Expect(src).To(Equal([]int{1, 2, 3, 4}), "original slice should not be modified") + }) + + }) + + Context("ProjectMapToSlice", func() { + + projectFunc := func(k string, v string) string { + return k + ":" + v + } + + It("should use the projection function on each key-value pair of the map", func() { + src := map[string]string{"a": "1", "b": "2", "c": "3"} + projected := collections.ProjectMapToSlice(src, projectFunc) + Expect(projected).To(ConsistOf("a:1", "b:2", "c:3")) + Expect(src).To(Equal(map[string]string{"a": "1", "b": "2", "c": "3"}), "original map should not be modified") + }) + + It("should return an empty slice for an empty or nil input map", func() { + Expect(collections.ProjectMapToSlice(nil, projectFunc)).To(BeEmpty()) + Expect(collections.ProjectMapToSlice(map[string]string{}, projectFunc)).To(BeEmpty()) + }) + + It("should return nil for a nil projection function", func() { + src := map[string]string{"a": "1", "b": "2", "c": "3"} + projected := collections.ProjectMapToSlice[string, string, string](src, nil) + Expect(projected).To(BeNil()) + Expect(src).To(Equal(map[string]string{"a": "1", "b": "2", "c": "3"}), "original map should not be modified") + }) + + }) + + Context("ProjectMapToMap", func() { + + projectFunc := func(k string, v string) (string, int) { + return k, len(v) + } + + It("should use the projection function on each key-value pair of the map", func() { + src := map[string]string{"a": "1", "b": "22", "c": "333"} + projected := collections.ProjectMapToMap(src, projectFunc) + Expect(projected).To(Equal(map[string]int{"a": 1, "b": 2, "c": 3})) + Expect(src).To(Equal(map[string]string{"a": "1", "b": "22", "c": "333"}), "original map should not be modified") + }) + + It("should return an empty map for an empty or nil input map", func() { + Expect(collections.ProjectMapToMap(nil, projectFunc)).To(BeEmpty()) + Expect(collections.ProjectMapToMap(map[string]string{}, projectFunc)).To(BeEmpty()) + }) + + It("should return nil for a nil projection function", func() { + src := map[string]string{"a": "1", "b": "22", "c": "333"} + projected := collections.ProjectMapToMap[string, string, string, int](src, nil) + Expect(projected).To(BeNil()) + Expect(src).To(Equal(map[string]string{"a": "1", "b": "22", "c": "333"}), "original map should not be modified") + }) + + }) + +}) diff --git a/pkg/controller/utils.go b/pkg/controller/utils.go index 44423a5..b78baa1 100644 --- a/pkg/controller/utils.go +++ b/pkg/controller/utils.go @@ -52,3 +52,29 @@ func ObjectKey(name string, maybeNamespace ...string) client.ObjectKey { Name: name, } } + +// RemoveFinalizersWithPrefix removes finalizers with a given prefix from the object and returns their suffixes. +// If the third argument is true, all finalizers with the given prefix are removed, otherwise only the first one. +// The bool return value indicates whether a finalizer was removed. +// If it is true, the slice return value holds the suffixes of all removed finalizers (will be of length 1 if removeAll is false). +// If it is false, no finalizer with the given prefix was found. The slice return value will be empty in this case. +// The logic is based on the controller-runtime's RemoveFinalizer function. +func RemoveFinalizersWithPrefix(obj client.Object, prefix string, removeAll bool) ([]string, bool) { + fins := obj.GetFinalizers() + length := len(fins) + suffixes := make([]string, 0, length) + found := false + + index := 0 + for i := range length { + if (removeAll || !found) && strings.HasPrefix(fins[i], prefix) { + suffixes = append(suffixes, strings.TrimPrefix(fins[i], prefix)) + found = true + continue + } + fins[index] = fins[i] + index++ + } + obj.SetFinalizers(fins[:index]) + return suffixes, length != index +} diff --git a/pkg/controller/utils_test.go b/pkg/controller/utils_test.go index 431615b..22f015c 100644 --- a/pkg/controller/utils_test.go +++ b/pkg/controller/utils_test.go @@ -2,49 +2,92 @@ package controller import ( "fmt" - "testing" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/openmcp-project/controller-utils/pkg/pairs" + + corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/util/validation" ) -func TestK8sNameHash(t *testing.T) { - tt := []struct { - input []string - expHash string - }{ - { - []string{"test1"}, - "dnhq5gcrs4mzrzzsa6cujsllg3b5ahhn67fkgmrvtvxr3a2woaka", - }, - { - // check that the same string produces the same hash - []string{"test1"}, - "dnhq5gcrs4mzrzzsa6cujsllg3b5ahhn67fkgmrvtvxr3a2woaka", - }, - { - []string{"bla"}, - "jxz4h5upzsb3e7u5ileqimnhesm7c6dvzanftg2wnsmitoljm4bq", - }, - { - []string{"some other test", "this is a very, very long string"}, - "rjphpfjbmwn6qqydv6xhtmj3kxrlzepn2tpwy4okw2ypoc3nlffq", - }, - } - - for _, tc := range tt { - t.Run(fmt.Sprint(tc.input), func(t *testing.T) { - res := K8sNameHash(tc.input...) - - if res != tc.expHash { - t.Errorf("exp hash %q, got %q", tc.expHash, res) +var _ = Describe("Predicates", func() { + + Context("K8sNameHash", func() { + + testData := []pairs.Pair[*[]string, string]{ + { + Key: &[]string{"test1"}, + Value: "dnhq5gcrs4mzrzzsa6cujsllg3b5ahhn67fkgmrvtvxr3a2woaka", + }, + { + Key: &[]string{"bla"}, + Value: "jxz4h5upzsb3e7u5ileqimnhesm7c6dvzanftg2wnsmitoljm4bq", + }, + { + Key: &[]string{"some other test", "this is a very, very long string"}, + Value: "rjphpfjbmwn6qqydv6xhtmj3kxrlzepn2tpwy4okw2ypoc3nlffq", + }, + } + + It("should generate the same hash for the same input value", func() { + for _, p := range testData { + for range 5 { + res := K8sNameHash(*p.Key...) + Expect(res).To(Equal(p.Value)) + } } + }) + + It("should generate different hashes for different input values", func() { + res1 := K8sNameHash(*testData[0].Key...) + res2 := K8sNameHash(*testData[1].Key...) + res3 := K8sNameHash(*testData[2].Key...) + Expect(res1).NotTo(Equal(res2)) + Expect(res1).NotTo(Equal(res3)) + Expect(res2).NotTo(Equal(res3)) + }) - // ensure the result is a valid DNS1123Subdomain - if errs := validation.IsDNS1123Subdomain(res); errs != nil { - t.Errorf("value %q is invalid: %v", res, errs) + It("should generate a valid DNS1123Subdomain", func() { + for _, p := range testData { + res := K8sNameHash(*p.Key...) + errs := validation.IsDNS1123Subdomain(res) + Expect(errs).To(BeEmpty(), fmt.Sprintf("value %q is invalid: %v", res, errs)) } + }) + + }) + + Context("RemoveFinalizersWithPrefix", func() { + + It("should only remove the first finalizer with the given prefix", func() { + ns := &corev1.Namespace{} + ns.SetFinalizers([]string{"foo/bar", "baz/qux", "foo/baz"}) + suffix, removed := RemoveFinalizersWithPrefix(ns, "foo/", false) + Expect(removed).To(BeTrue()) + Expect(suffix).To(ConsistOf("bar")) + Expect(ns.GetFinalizers()).To(Equal([]string{"baz/qux", "foo/baz"}), "should remove only the first matching finalizer") + }) + + It("should remove all finalizers with the given prefix", func() { + ns := &corev1.Namespace{} + ns.SetFinalizers([]string{"foo/bar", "baz/qux", "foo/baz"}) + suffix, removed := RemoveFinalizersWithPrefix(ns, "foo/", true) + Expect(removed).To(BeTrue()) + Expect(suffix).To(ConsistOf("bar", "baz")) + Expect(ns.GetFinalizers()).To(Equal([]string{"baz/qux"}), "should remove all matching finalizers") + }) + It("should return false if no finalizer with the given prefix exists", func() { + ns := &corev1.Namespace{} + ns.SetFinalizers([]string{"foo/bar", "baz/qux"}) + suffix, removed := RemoveFinalizersWithPrefix(ns, "nonexistent/", false) + Expect(removed).To(BeFalse()) + Expect(suffix).To(BeEmpty()) + Expect(ns.GetFinalizers()).To(Equal([]string{"foo/bar", "baz/qux"}), "should not modify finalizers if no match is found") }) - } -} + }) + +}) diff --git a/pkg/pairs/pairs_test.go b/pkg/pairs/pairs_test.go index 0c68421..2e7e6c5 100644 --- a/pkg/pairs/pairs_test.go +++ b/pkg/pairs/pairs_test.go @@ -14,7 +14,7 @@ import ( func TestConditions(t *testing.T) { RegisterFailHandler(Fail) - RunSpecs(t, "ClusterAccess Test Suite") + RunSpecs(t, "Pairs Test Suite") } type comparableIntAlias int diff --git a/pkg/retry/retry.go b/pkg/retry/retry.go new file mode 100644 index 0000000..3edda49 --- /dev/null +++ b/pkg/retry/retry.go @@ -0,0 +1,324 @@ +package retry + +import ( + "context" + "reflect" + "time" + + "k8s.io/apimachinery/pkg/api/meta" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" +) + +type Client struct { + internal client.Client + interval time.Duration + backoffMultiplier float64 + maxAttempts int + timeout time.Duration +} + +// NewRetryingClient returns a retry.Client that implements client.Client, but retries each operation that can fail with the specified parameters. +// Returns nil if the provided client is nil. +// The default parameters are: +// - interval: 100 milliseconds +// - backoffMultiplier: 1.0 (no backoff) +// - maxAttempts: 0 (no limit on attempts) +// - timeout: 1 second (timeout for retries) +// Use the builder-style With... methods to adapt the parameters. +func NewRetryingClient(c client.Client) *Client { + if c == nil { + return nil + } + return &Client{ + internal: c, + interval: 100 * time.Millisecond, // default retry interval + backoffMultiplier: 1.0, // default backoff multiplier + maxAttempts: 0, // default max retries + timeout: 1 * time.Second, // default timeout for retries + } +} + +var _ client.Client = &Client{} + +///////////// +// GETTERS // +///////////// + +// Interval returns the configured retry interval. +func (rc *Client) Interval() time.Duration { + return rc.interval +} + +// BackoffMultiplier returns the configured backoff multiplier for retries. +func (rc *Client) BackoffMultiplier() float64 { + return rc.backoffMultiplier +} + +// MaxRetries returns the configured maximum number of retries. +func (rc *Client) MaxRetries() int { + return rc.maxAttempts +} + +// Timeout returns the configured timeout for retries. +func (rc *Client) Timeout() time.Duration { + return rc.timeout +} + +///////////// +// SETTERS // +///////////// + +// WithInterval sets the retry interval for the Client. +// Default is 100 milliseconds. +// Noop if the interval is less than or equal to 0. +// It returns the Client for chaining. +func (rc *Client) WithInterval(interval time.Duration) *Client { + if interval > 0 { + rc.interval = interval + } + return rc +} + +// WithBackoffMultiplier sets the backoff multiplier for the Client. +// After each retry, the configured interval is multiplied by this factor. +// Setting it to a value less than 1 will default it to 1. +// Default is 1.0, meaning no backoff. +// Noop if the multiplier is less than 1. +// It returns the Client for chaining. +func (rc *Client) WithBackoffMultiplier(multiplier float64) *Client { + if multiplier >= 1 { + rc.backoffMultiplier = multiplier + } + return rc +} + +// WithMaxAttempts sets the maximum number of attempts for the Client. +// If set to 0, it will retry indefinitely until the timeout is reached. +// Default is 0, meaning no limit on attempts. +// Noop if the maxAttempts is less than 0. +// It returns the Client for chaining. +func (rc *Client) WithMaxAttempts(maxAttempts int) *Client { + if maxAttempts >= 0 { + rc.maxAttempts = maxAttempts + } + return rc +} + +// WithTimeout sets the timeout for retries in the Client. +// If set to 0, there is no timeout and it will retry until the maximum number of retries is reached. +// Default is 1 second. +// Noop if the timeout is less than 0. +// It returns the Client for chaining. +func (rc *Client) WithTimeout(timeout time.Duration) *Client { + if timeout >= 0 { + rc.timeout = timeout + } + return rc +} + +/////////////////////////// +// CLIENT IMPLEMENTATION // +/////////////////////////// + +type operation struct { + parent *Client + interval time.Duration + attempts int + startTime time.Time + method reflect.Value + args []reflect.Value +} + +func (rc *Client) newOperation(method reflect.Value, args ...any) *operation { + op := &operation{ + parent: rc, + interval: rc.interval, + attempts: 0, + startTime: time.Now(), + method: method, + } + if method.Type().IsVariadic() { + argCountWithoutVariadic := len(args) - 1 + last := args[argCountWithoutVariadic] + lastVal := reflect.ValueOf(last) + argCountVariadic := lastVal.Len() + op.args = make([]reflect.Value, argCountWithoutVariadic+argCountVariadic) + for i, arg := range args[:argCountWithoutVariadic] { + op.args[i] = reflect.ValueOf(arg) + } + for i := range argCountVariadic { + op.args[argCountWithoutVariadic+i] = lastVal.Index(i) + } + } else { + op.args = make([]reflect.Value, len(args)) + for i, arg := range args { + op.args[i] = reflect.ValueOf(arg) + } + } + return op +} + +// try attempts the operation. +// The first return value indicates success (true) or failure (false). +// The second return value is the duration to wait before the next retry. +// +// If it is 0, no retry is needed. +// This can be because the operation succeeded, or because the timeout or retry limit was reached. +// +// The third return value contains the return values of the operation. +func (op *operation) try() (bool, time.Duration, []reflect.Value) { + res := op.method.Call(op.args) + + // check for success by converting the last return value to an error + success := true + if len(res) > 0 { + if err, ok := res[len(res)-1].Interface().(error); ok && err != nil { + success = false + } + } + + // if the operation succeeded, return true and no retry + if success { + return true, 0, res + } + + // if the operation failed, check if we should retry + op.attempts++ + retryAfter := op.interval + op.interval = time.Duration(float64(op.interval) * op.parent.backoffMultiplier) + if (op.parent.maxAttempts > 0 && op.attempts >= op.parent.maxAttempts) || (op.parent.timeout > 0 && time.Now().Add(retryAfter).After(op.startTime.Add(op.parent.timeout))) { + // if we reached the maximum number of retries or the next retry would exceed the timeout, return false and no retry + return false, 0, res + } + + return false, retryAfter, res +} + +// retry executes the given method with the provided arguments, retrying on failure. +func (rc *Client) retry(method reflect.Value, args ...any) []reflect.Value { + op := rc.newOperation(method, args...) + var ctx context.Context + if len(args) > 0 { + if ctxArg, ok := args[0].(context.Context); ok { + ctx = ctxArg + } + } + if ctx == nil { + ctx = context.Background() + } + if rc.Timeout() > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(ctx, op.startTime.Add(rc.timeout)) + defer cancel() + } + interruptedOrTimeouted := ctx.Done() + success, retryAfter, res := op.try() + for !success && retryAfter > 0 { + opCtx, opCancel := context.WithTimeout(ctx, retryAfter) + expired := opCtx.Done() + select { + case <-interruptedOrTimeouted: + retryAfter = 0 // stop retrying if the context was cancelled + case <-expired: + success, retryAfter, res = op.try() + } + opCancel() + } + return res +} + +func errOrNil(val reflect.Value) error { + if val.IsNil() { + return nil + } + return val.Interface().(error) +} + +// CreateOrUpdate wraps the controllerutil.CreateOrUpdate function and retries it on failure. +func (rc *Client) CreateOrUpdate(ctx context.Context, obj client.Object, f controllerutil.MutateFn) (controllerutil.OperationResult, error) { + res := rc.retry(reflect.ValueOf(controllerutil.CreateOrUpdate), ctx, rc.internal, obj, f) + return res[0].Interface().(controllerutil.OperationResult), errOrNil(res[1]) +} + +// CreateOrPatch wraps the controllerutil.CreateOrPatch function and retries it on failure. +func (rc *Client) CreateOrPatch(ctx context.Context, obj client.Object, f controllerutil.MutateFn) (controllerutil.OperationResult, error) { + res := rc.retry(reflect.ValueOf(controllerutil.CreateOrPatch), ctx, rc.internal, obj, f) + return res[0].Interface().(controllerutil.OperationResult), errOrNil(res[1]) +} + +// Create wraps the client's Create method and retries it on failure. +func (rc *Client) Create(ctx context.Context, obj client.Object, opts ...client.CreateOption) error { + res := rc.retry(reflect.ValueOf(rc.internal.Create), ctx, obj, opts) + return errOrNil(res[0]) +} + +// Delete wraps the client's Delete method and retries it on failure. +func (rc *Client) Delete(ctx context.Context, obj client.Object, opts ...client.DeleteOption) error { + res := rc.retry(reflect.ValueOf(rc.internal.Delete), ctx, obj, opts) + return errOrNil(res[0]) +} + +// DeleteAllOf wraps the client's DeleteAllOf method and retries it on failure. +func (rc *Client) DeleteAllOf(ctx context.Context, obj client.Object, opts ...client.DeleteAllOfOption) error { + res := rc.retry(reflect.ValueOf(rc.internal.DeleteAllOf), ctx, obj, opts) + return errOrNil(res[0]) +} + +// Get wraps the client's Get method and retries it on failure. +func (rc *Client) Get(ctx context.Context, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { + res := rc.retry(reflect.ValueOf(rc.internal.Get), ctx, key, obj, opts) + return errOrNil(res[0]) +} + +// List wraps the client's List method and retries it on failure. +func (rc *Client) List(ctx context.Context, list client.ObjectList, opts ...client.ListOption) error { + res := rc.retry(reflect.ValueOf(rc.internal.List), ctx, list, opts) + return errOrNil(res[0]) +} + +// Patch wraps the client's Patch method and retries it on failure. +func (rc *Client) Patch(ctx context.Context, obj client.Object, patch client.Patch, opts ...client.PatchOption) error { + res := rc.retry(reflect.ValueOf(rc.internal.Patch), ctx, obj, patch, opts) + return errOrNil(res[0]) +} + +// Update wraps the client's Update method and retries it on failure. +func (rc *Client) Update(ctx context.Context, obj client.Object, opts ...client.UpdateOption) error { + res := rc.retry(reflect.ValueOf(rc.internal.Update), ctx, obj, opts) + return errOrNil(res[0]) +} + +// GroupVersionKindFor wraps the client's GroupVersionKindFor method and retries it on failure. +func (rc *Client) GroupVersionKindFor(obj runtime.Object) (schema.GroupVersionKind, error) { + res := rc.retry(reflect.ValueOf(rc.internal.GroupVersionKindFor), obj) + return res[0].Interface().(schema.GroupVersionKind), errOrNil(res[1]) +} + +// IsObjectNamespaced wraps the client's IsObjectNamespaced method and retries it on failure. +func (rc *Client) IsObjectNamespaced(obj runtime.Object) (bool, error) { + res := rc.retry(reflect.ValueOf(rc.internal.IsObjectNamespaced), obj) + return res[0].Interface().(bool), errOrNil(res[1]) +} + +// RESTMapper calls the internal client's RESTMapper method. +func (rc *Client) RESTMapper() meta.RESTMapper { + return rc.internal.RESTMapper() +} + +// Scheme calls the internal client's Scheme method. +func (rc *Client) Scheme() *runtime.Scheme { + return rc.internal.Scheme() +} + +// Status calls the internal client's Status method. +func (rc *Client) Status() client.SubResourceWriter { + return rc.internal.Status() +} + +// SubResource calls the internal client's SubResource method. +func (rc *Client) SubResource(subResource string) client.SubResourceClient { + return rc.internal.SubResource(subResource) +} diff --git a/pkg/retry/retry_test.go b/pkg/retry/retry_test.go new file mode 100644 index 0000000..d163307 --- /dev/null +++ b/pkg/retry/retry_test.go @@ -0,0 +1,407 @@ +package retry_test + +import ( + "context" + "fmt" + "testing" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + . "github.com/onsi/gomega/gstruct" + + corev1 "k8s.io/api/core/v1" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/interceptor" + + "github.com/openmcp-project/controller-utils/pkg/retry" + testutils "github.com/openmcp-project/controller-utils/pkg/testing" +) + +func TestConditions(t *testing.T) { + RegisterFailHandler(Fail) + + RunSpecs(t, "Retry Test Suite") +} + +// mockControl is a helper struct to control the behavior of the fake client. +// Each attempt will increase the 'attempts' counter. +// Returns a mockError if +// - 'fail' is less than 0 +// - 'fail' is greater than 0 and the number of attempts is less than or equal to 'fail' +// Returns nil otherwise. +type mockControl struct { + fail int + attempts int +} + +func (mc *mockControl) reset(failCount int) { + mc.fail = failCount + mc.attempts = 0 +} + +func (mc *mockControl) try() error { + mc.attempts++ + if mc.fail < 0 || (mc.fail > 0 && mc.attempts <= mc.fail) { + return errMock + } + return nil +} + +var errMock = fmt.Errorf("mock error") + +func defaultTestSetup() (*testutils.Environment, *mockControl) { + mc := &mockControl{} + return testutils.NewEnvironmentBuilder(). + WithFakeClient(nil). + WithFakeClientBuilderCall("WithInterceptorFuncs", interceptor.Funcs{ + Get: func(ctx context.Context, client client.WithWatch, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { + if err := mc.try(); err != nil { + return err + } + return client.Get(ctx, key, obj, opts...) + }, + List: func(ctx context.Context, client client.WithWatch, list client.ObjectList, opts ...client.ListOption) error { + if err := mc.try(); err != nil { + return err + } + return client.List(ctx, list, opts...) + }, + Create: func(ctx context.Context, client client.WithWatch, obj client.Object, opts ...client.CreateOption) error { + if err := mc.try(); err != nil { + return err + } + return client.Create(ctx, obj, opts...) + }, + Delete: func(ctx context.Context, client client.WithWatch, obj client.Object, opts ...client.DeleteOption) error { + if err := mc.try(); err != nil { + return err + } + return client.Delete(ctx, obj, opts...) + }, + DeleteAllOf: func(ctx context.Context, client client.WithWatch, obj client.Object, opts ...client.DeleteAllOfOption) error { + if err := mc.try(); err != nil { + return err + } + return client.DeleteAllOf(ctx, obj, opts...) + }, + Update: func(ctx context.Context, client client.WithWatch, obj client.Object, opts ...client.UpdateOption) error { + if err := mc.try(); err != nil { + return err + } + return client.Update(ctx, obj, opts...) + }, + Patch: func(ctx context.Context, client client.WithWatch, obj client.Object, patch client.Patch, opts ...client.PatchOption) error { + if err := mc.try(); err != nil { + return err + } + return client.Patch(ctx, obj, patch, opts...) + }, + }). + Build(), mc +} + +var _ = Describe("Client", func() { + + It("should not retry if the operation succeeds immediately", func() { + env, mc := defaultTestSetup() + c := retry.NewRetryingClient(env.Client()) + + // create a Namespace + ns := &corev1.Namespace{} + ns.Name = "test" + mc.reset(0) + Expect(c.Create(env.Ctx, ns)).To(Succeed()) + Expect(mc.attempts).To(Equal(1)) + + // get the Namespace + mc.reset(0) + Expect(c.Get(env.Ctx, client.ObjectKeyFromObject(ns), ns)).To(Succeed()) + Expect(mc.attempts).To(Equal(1)) + + // list Namespaces + mc.reset(0) + nsList := &corev1.NamespaceList{} + Expect(c.List(env.Ctx, nsList)).To(Succeed()) + Expect(mc.attempts).To(Equal(1)) + Expect(nsList.Items).To(ContainElement(MatchFields(IgnoreExtras, Fields{ + "ObjectMeta": MatchFields(IgnoreExtras, Fields{ + "Name": Equal("test"), + }), + }))) + + // update the Namespace + mc.reset(0) + ns.Labels = map[string]string{"test": "label"} + Expect(c.Update(env.Ctx, ns)).To(Succeed()) + Expect(mc.attempts).To(Equal(1)) + + // patch the Namespace + mc.reset(0) + old := ns.DeepCopy() + ns.Labels = nil + Expect(c.Patch(env.Ctx, ns, client.MergeFrom(old))).To(Succeed()) + Expect(mc.attempts).To(Equal(1)) + + // delete the Namespace + mc.reset(0) + Expect(c.Delete(env.Ctx, ns)).To(Succeed()) + Expect(mc.attempts).To(Equal(1)) + + // delete all Namespaces + mc.reset(0) + Expect(c.DeleteAllOf(env.Ctx, &corev1.Namespace{})).To(Succeed()) + Expect(mc.attempts).To(Equal(1)) + }) + + It("should retry if the operation does not succeed immediately", func() { + env, mc := defaultTestSetup() + c := retry.NewRetryingClient(env.Client()).WithMaxAttempts(5).WithTimeout(0) + + // create a Namespace + ns := &corev1.Namespace{} + ns.Name = "test" + mc.reset(2) + Expect(c.Create(env.Ctx, ns)).To(Succeed()) + Expect(mc.attempts).To(Equal(3)) + + // get the Namespace + mc.reset(2) + Expect(c.Get(env.Ctx, client.ObjectKeyFromObject(ns), ns)).To(Succeed()) + Expect(mc.attempts).To(Equal(3)) + + // list Namespaces + mc.reset(2) + nsList := &corev1.NamespaceList{} + Expect(c.List(env.Ctx, nsList)).To(Succeed()) + Expect(mc.attempts).To(Equal(3)) + Expect(nsList.Items).To(ContainElement(MatchFields(IgnoreExtras, Fields{ + "ObjectMeta": MatchFields(IgnoreExtras, Fields{ + "Name": Equal("test"), + }), + }))) + + // update the Namespace + mc.reset(2) + ns.Labels = map[string]string{"test": "label"} + Expect(c.Update(env.Ctx, ns)).To(Succeed()) + Expect(mc.attempts).To(Equal(3)) + + // patch the Namespace + mc.reset(2) + old := ns.DeepCopy() + ns.Labels = nil + Expect(c.Patch(env.Ctx, ns, client.MergeFrom(old))).To(Succeed()) + Expect(mc.attempts).To(Equal(3)) + + // delete the Namespace + mc.reset(2) + Expect(c.Delete(env.Ctx, ns)).To(Succeed()) + Expect(mc.attempts).To(Equal(3)) + + // delete all Namespaces + mc.reset(2) + Expect(c.DeleteAllOf(env.Ctx, &corev1.Namespace{})).To(Succeed()) + Expect(mc.attempts).To(Equal(3)) + }) + + It("should not retry more often than configured", func() { + env, mc := defaultTestSetup() + c := retry.NewRetryingClient(env.Client()).WithMaxAttempts(5).WithTimeout(0) + + // create a Namespace + ns := &corev1.Namespace{} + ns.Name = "test" + mc.reset(-1) + Expect(c.Create(env.Ctx, ns)).ToNot(Succeed()) + Expect(mc.attempts).To(Equal(5)) + + // get the Namespace + mc.reset(-1) + Expect(c.Get(env.Ctx, client.ObjectKeyFromObject(ns), ns)).ToNot(Succeed()) + Expect(mc.attempts).To(Equal(5)) + + // list Namespaces + mc.reset(-1) + nsList := &corev1.NamespaceList{} + Expect(c.List(env.Ctx, nsList)).ToNot(Succeed()) + Expect(mc.attempts).To(Equal(5)) + + // update the Namespace + mc.reset(-1) + ns.Labels = map[string]string{"test": "label"} + Expect(c.Update(env.Ctx, ns)).ToNot(Succeed()) + Expect(mc.attempts).To(Equal(5)) + + // patch the Namespace + mc.reset(-1) + old := ns.DeepCopy() + ns.Labels = nil + Expect(c.Patch(env.Ctx, ns, client.MergeFrom(old))).ToNot(Succeed()) + Expect(mc.attempts).To(Equal(5)) + + // delete the Namespace + mc.reset(-1) + Expect(c.Delete(env.Ctx, ns)).ToNot(Succeed()) + Expect(mc.attempts).To(Equal(5)) + + // delete all Namespaces + mc.reset(-1) + Expect(c.DeleteAllOf(env.Ctx, &corev1.Namespace{})).ToNot(Succeed()) + Expect(mc.attempts).To(Equal(5)) + }) + + It("should not retry longer than configured", func() { + env, mc := defaultTestSetup() + c := retry.NewRetryingClient(env.Client()).WithMaxAttempts(0).WithTimeout(500 * time.Millisecond) + + // for performance reasons, let's test this for Create only + ns := &corev1.Namespace{} + ns.Name = "test" + mc.reset(-1) + now := time.Now() + timeoutCtx, cancel := context.WithTimeout(env.Ctx, 1*time.Second) + defer cancel() + Expect(c.Create(timeoutCtx, ns)).ToNot(Succeed()) + after := time.Now() + Expect(after.Sub(now)).To(BeNumerically(">=", 400*time.Millisecond)) + Expect(after.Sub(now)).To(BeNumerically("<", 1*time.Second)) + Expect(mc.attempts).To(BeNumerically(">=", 4)) + Expect(mc.attempts).To(BeNumerically("<=", 5)) + }) + + It("should apply the backoff multiplier correctly", func() { + env, mc := defaultTestSetup() + c := retry.NewRetryingClient(env.Client()).WithMaxAttempts(0).WithTimeout(500 * time.Millisecond).WithBackoffMultiplier(3.0) + + // for performance reasons, let's test this for Create only + ns := &corev1.Namespace{} + ns.Name = "test" + mc.reset(-1) + now := time.Now() + timeoutCtx, cancel := context.WithTimeout(env.Ctx, 1*time.Second) + defer cancel() + Expect(c.Create(timeoutCtx, ns)).ToNot(Succeed()) + after := time.Now() + Expect(after.Sub(now)).To(BeNumerically(">=", 400*time.Millisecond)) + Expect(after.Sub(now)).To(BeNumerically("<", 1*time.Second)) + Expect(mc.attempts).To(BeNumerically("==", 3)) + }) + + It("should abort if the context is canceled", func() { + env, mc := defaultTestSetup() + c := retry.NewRetryingClient(env.Client()).WithMaxAttempts(0).WithTimeout(500 * time.Millisecond) + + // for performance reasons, let's test this for Create only + ns := &corev1.Namespace{} + ns.Name = "test" + mc.reset(-1) + now := time.Now() + timeoutCtx, cancel := context.WithTimeout(env.Ctx, 200*time.Millisecond) + defer cancel() + Expect(c.Create(timeoutCtx, ns)).ToNot(Succeed()) + after := time.Now() + Expect(after.Sub(now)).To(BeNumerically("<", 300*time.Millisecond)) + Expect(mc.attempts).To(BeNumerically("<=", 3)) + }) + + It("should pass the arguments through correctly", func() { + env, mc := defaultTestSetup() + c := retry.NewRetryingClient(env.Client()) + + // for performance reasons, let's test this for Create only + s1 := &corev1.Secret{} + s1.Name = "test" + s1.Namespace = "foo" + Expect(env.Client().Create(env.Ctx, s1)).To(Succeed()) + s2 := &corev1.Secret{} + s2.Name = "test" + s2.Namespace = "bar" + Expect(env.Client().Create(env.Ctx, s2)).To(Succeed()) + mc.reset(0) + l1 := &corev1.SecretList{} + Expect(c.List(env.Ctx, l1)).To(Succeed()) + Expect(mc.attempts).To(Equal(1)) + Expect(l1.Items).To(ConsistOf( + MatchFields(IgnoreExtras, Fields{ + "ObjectMeta": MatchFields(IgnoreExtras, Fields{ + "Name": Equal("test"), + "Namespace": Equal("foo"), + }), + }), + MatchFields(IgnoreExtras, Fields{ + "ObjectMeta": MatchFields(IgnoreExtras, Fields{ + "Name": Equal("test"), + "Namespace": Equal("bar"), + }), + }), + )) + mc.reset(0) + l2 := &corev1.SecretList{} + Expect(c.List(env.Ctx, l2, client.InNamespace("foo"))).To(Succeed()) + Expect(mc.attempts).To(Equal(1)) + Expect(l2.Items).To(ConsistOf( + MatchFields(IgnoreExtras, Fields{ + "ObjectMeta": MatchFields(IgnoreExtras, Fields{ + "Name": Equal("test"), + "Namespace": Equal("foo"), + }), + }), + )) + }) + + It("should correctly handle CreateOrUpdate and CreateOrPatch", func() { + env, mc := defaultTestSetup() + c := retry.NewRetryingClient(env.Client()).WithMaxAttempts(5).WithTimeout(0) + + // create or update namespace + // we cannot check mc.attempts here, because CreateOrUpdate calls multiple methods on the client internally + ns := &corev1.Namespace{} + ns.Name = "test" + mc.reset(0) + Expect(c.CreateOrUpdate(env.Ctx, ns, func() error { + return nil + })) + Expect(env.Client().Get(env.Ctx, client.ObjectKeyFromObject(ns), ns)).To(Succeed()) + mc.reset(0) + Expect(c.CreateOrUpdate(env.Ctx, ns, func() error { + ns.Labels = map[string]string{"test": "label"} + return nil + })) + Expect(env.Client().Get(env.Ctx, client.ObjectKeyFromObject(ns), ns)).To(Succeed()) + Expect(ns.Labels).To(HaveKeyWithValue("test", "label")) + mc.reset(2) + Expect(c.CreateOrUpdate(env.Ctx, ns, func() error { + ns.Labels = map[string]string{"test2": "label2"} + return nil + })) + Expect(env.Client().Get(env.Ctx, client.ObjectKeyFromObject(ns), ns)).To(Succeed()) + Expect(ns.Labels).To(HaveKeyWithValue("test2", "label2")) + Expect(env.Client().Delete(env.Ctx, ns)).To(Succeed()) + + // create or patch namespace + ns = &corev1.Namespace{} + ns.Name = "test" + mc.reset(0) + Expect(c.CreateOrPatch(env.Ctx, ns, func() error { + return nil + })) + Expect(env.Client().Get(env.Ctx, client.ObjectKeyFromObject(ns), ns)).To(Succeed()) + mc.reset(0) + Expect(c.CreateOrPatch(env.Ctx, ns, func() error { + ns.Labels = map[string]string{"test": "label"} + return nil + })) + Expect(env.Client().Get(env.Ctx, client.ObjectKeyFromObject(ns), ns)).To(Succeed()) + Expect(ns.Labels).To(HaveKeyWithValue("test", "label")) + mc.reset(2) + Expect(c.CreateOrUpdate(env.Ctx, ns, func() error { + ns.Labels = map[string]string{"test2": "label2"} + return nil + })) + Expect(env.Client().Get(env.Ctx, client.ObjectKeyFromObject(ns), ns)).To(Succeed()) + Expect(ns.Labels).To(HaveKeyWithValue("test2", "label2")) + Expect(env.Client().Delete(env.Ctx, ns)).To(Succeed()) + }) + +}) diff --git a/pkg/testing/complex_environment.go b/pkg/testing/complex_environment.go index 2b42e4c..c44f2e3 100644 --- a/pkg/testing/complex_environment.go +++ b/pkg/testing/complex_environment.go @@ -9,10 +9,12 @@ import ( "github.com/onsi/gomega" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/uuid" clientgoscheme "k8s.io/client-go/kubernetes/scheme" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" + "sigs.k8s.io/controller-runtime/pkg/client/interceptor" "sigs.k8s.io/controller-runtime/pkg/reconcile" "github.com/openmcp-project/controller-utils/pkg/logging" @@ -121,6 +123,7 @@ type ComplexEnvironmentBuilder struct { ClusterInitObjectPaths map[string][]string ClientCreationCallbacks map[string][]func(client.Client) loggerIsSet bool + InjectUIDs map[string]bool } type ClusterEnvironment struct { @@ -163,6 +166,7 @@ func NewComplexEnvironmentBuilder() *ComplexEnvironmentBuilder { ClusterStatusObjects: map[string][]client.Object{}, ClusterInitObjectPaths: map[string][]string{}, ClientCreationCallbacks: map[string][]func(client.Client){}, + InjectUIDs: map[string]bool{}, } } @@ -264,6 +268,16 @@ func (eb *ComplexEnvironmentBuilder) WithAfterClientCreationCallback(name string return eb } +// WithUIDs enables UID injection for the specified cluster. +// All objects that are initially loaded or afterwards created via the client's 'Create' method will have a random UID injected, if they do not already have one. +// Note that this function registers an interceptor function, which will be overwritten if 'WithFakeClientBuilderCall(..., "WithInterceptorFuncs", ...)' is also called. +// This would lead to newly created objects not having a UID injected. +// To avoid this, pass 'InjectUIDOnObjectCreation(...)' into the interceptor.Funcs' Create field. The argument allows to inject your own additional Create logic, if desired. +func (eb *ComplexEnvironmentBuilder) WithUIDs(name string) *ComplexEnvironmentBuilder { + eb.InjectUIDs[name] = true + return eb +} + // WithFakeClientBuilderCall allows to inject method calls to fake.ClientBuilder when the fake clients are created during Build(). // The fake clients are usually created using WithScheme(...).WithObjects(...).WithStatusSubresource(...).Build(). // This function allows to inject additional method calls. It is only required for advanced use-cases. @@ -284,6 +298,8 @@ func (eb *ComplexEnvironmentBuilder) WithFakeClientBuilderCall(name string, meth // Build constructs the environment from the builder. // Note that this function panics instead of throwing an error, // as it is intended to be used in tests, where all information is static anyway. +// +//nolint:gocyclo func (eb *ComplexEnvironmentBuilder) Build() *ComplexEnvironment { res := eb.internal @@ -335,6 +351,18 @@ func (eb *ComplexEnvironmentBuilder) Build() *ComplexEnvironment { if len(eb.ClusterInitObjects) > 0 { objs = append(objs, eb.ClusterInitObjects[name]...) } + if eb.InjectUIDs[name] { + // ensure that objects have a uid + for _, obj := range objs { + if obj.GetUID() == "" { + // set a random UID if not already set + obj.SetUID(uuid.NewUUID()) + } + } + fcb.WithInterceptorFuncs(interceptor.Funcs{ + Create: InjectUIDOnObjectCreation(nil), + }) + } statusObjs := []client.Object{} statusObjs = append(statusObjs, objs...) statusObjs = append(statusObjs, eb.ClusterStatusObjects[name]...) @@ -396,3 +424,19 @@ func (eb *ComplexEnvironmentBuilder) Build() *ComplexEnvironment { return res } + +// InjectUIDOnObjectCreation returns an interceptor function for Create which injects a random UID into the object, if it does not already have one. +// If additionalLogic is nil, the object is created regularly afterwards. +// Otherwise, additionalLogic is called. +// If you called 'WithUIDs(...)' on the ComplexEnvironmentBuilder AND 'WithFakeClientBuilderCall(..., "WithInterceptorFuncs", ...)', then you need to pass this function into the interceptor.Funcs' Create field, optionally adding your own creation logic via additionalLogic. +func InjectUIDOnObjectCreation(additionalLogic func(ctx context.Context, client client.WithWatch, obj client.Object, opts ...client.CreateOption) error) func(ctx context.Context, client client.WithWatch, obj client.Object, opts ...client.CreateOption) error { + return func(ctx context.Context, client client.WithWatch, obj client.Object, opts ...client.CreateOption) error { + if obj.GetUID() == "" { + obj.SetUID(uuid.NewUUID()) + } + if additionalLogic != nil { + return additionalLogic(ctx, client, obj, opts...) + } + return client.Create(ctx, obj, opts...) + } +} diff --git a/pkg/testing/environment.go b/pkg/testing/environment.go index ace6cb1..f0ff9c7 100644 --- a/pkg/testing/environment.go +++ b/pkg/testing/environment.go @@ -153,6 +153,16 @@ func (eb *EnvironmentBuilder) WithAfterClientCreationCallback(callback func(clie return eb } +// WithUIDs enables UID injection. +// All objects that are initially loaded or afterwards created via the client's 'Create' method will have a random UID injected, if they do not already have one. +// Note that this function registers an interceptor function, which will be overwritten if 'WithFakeClientBuilderCall("WithInterceptorFuncs", ...)' is also called. +// This would lead to newly created objects not having a UID injected. +// To avoid this, pass 'InjectUIDOnObjectCreation(...)' into the interceptor.Funcs' Create field. The argument allows to inject your own additional Create logic, if desired. +func (eb *EnvironmentBuilder) WithUIDs() *EnvironmentBuilder { + eb.ComplexEnvironmentBuilder.WithUIDs(SimpleEnvironmentDefaultKey) + return eb +} + // WithFakeClientBuilderCall allows to inject method calls to fake.ClientBuilder when the fake client is created during Build(). // The fake client is usually created using WithScheme(...).WithObjects(...).WithStatusSubresource(...).Build(). // This function allows to inject additional method calls. It is only required for advanced use-cases.