diff --git a/context.go b/context.go index fbcbdc7..0797ab2 100644 --- a/context.go +++ b/context.go @@ -36,7 +36,7 @@ func StoreInContext(ctx context.Context, key, value interface{}) { // LoadFromContext returns the value from the given context with the given key. // It returns the value and true, or nil and false if the key doesn't exist. -// It may return nil and true if the key exists, but the value actually is nil. +// It returns nil and true if the key exists and the value actually is nil. // Use StoreInContext to store values. // // Note: This method is thread-safe, but panics if the ctx has not been set up with MutableContext first. @@ -49,3 +49,14 @@ func LoadFromContext(ctx context.Context, key interface{}) (interface{}, bool) { val, found := mp.Load(key) return val, found } + +// MustLoadFromContext is similar to LoadFromContext, except it doesn't return a bool to indicate whether the key exists. +// It returns nil if either the key doesn't exist, or if the value of the key is nil. +// Use StoreInContext to store values. +// +// Note: This is a convenience method for cases when it's not relevant whether the key is existing or not. +// Note: This method is thread-safe, but panics if the ctx has not been set up with MutableContext first. +func MustLoadFromContext(ctx context.Context, key interface{}) interface{} { + val, _ := LoadFromContext(ctx, key) + return val +} diff --git a/context_test.go b/context_test.go index 9775ff0..d4a8f45 100644 --- a/context_test.go +++ b/context_test.go @@ -62,6 +62,26 @@ func TestMutableContextRepeated(t *testing.T) { assert.Equal(t, result, repeated) } +func TestMustLoadFromContext(t *testing.T) { + t.Run("KeyExistsWithNil", func(t *testing.T) { + ctx := MutableContext(context.Background()) + StoreInContext(ctx, "key", nil) + result := MustLoadFromContext(ctx, "key") + assert.Nil(t, result) + }) + t.Run("KeyDoesntExist", func(t *testing.T) { + ctx := MutableContext(context.Background()) + result := MustLoadFromContext(ctx, "key") + assert.Nil(t, result) + }) + t.Run("KeyExistsWithValue", func(t *testing.T) { + ctx := MutableContext(context.Background()) + StoreInContext(ctx, "key", "value") + result := MustLoadFromContext(ctx, "key") + assert.Equal(t, "value", result) + }) +} + func ExampleMutableContext() { ctx := MutableContext(context.Background()) p := NewPipeline().WithSteps(