diff --git a/flyteplugins/go/tasks/pluginmachinery/core/mocks/connection_manager.go b/flyteplugins/go/tasks/pluginmachinery/core/mocks/connection_manager.go new file mode 100644 index 0000000000..08076e8578 --- /dev/null +++ b/flyteplugins/go/tasks/pluginmachinery/core/mocks/connection_manager.go @@ -0,0 +1,54 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + core "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + mock "github.com/stretchr/testify/mock" +) + +// ConnectionManager is an autogenerated mock type for the ConnectionManager type +type ConnectionManager struct { + mock.Mock +} + +type ConnectionManager_Get struct { + *mock.Call +} + +func (_m ConnectionManager_Get) Return(_a0 core.Connection, _a1 error) *ConnectionManager_Get { + return &ConnectionManager_Get{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *ConnectionManager) OnGet(ctx context.Context, key string) *ConnectionManager_Get { + c_call := _m.On("Get", ctx, key) + return &ConnectionManager_Get{Call: c_call} +} + +func (_m *ConnectionManager) OnGetMatch(matchers ...interface{}) *ConnectionManager_Get { + c_call := _m.On("Get", matchers...) + return &ConnectionManager_Get{Call: c_call} +} + +// Get provides a mock function with given fields: ctx, key +func (_m *ConnectionManager) Get(ctx context.Context, key string) (core.Connection, error) { + ret := _m.Called(ctx, key) + + var r0 core.Connection + if rf, ok := ret.Get(0).(func(context.Context, string) core.Connection); ok { + r0 = rf(ctx, key) + } else { + r0 = ret.Get(0).(core.Connection) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, key) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/flyteplugins/go/tasks/pluginmachinery/core/mocks/task_execution_context.go b/flyteplugins/go/tasks/pluginmachinery/core/mocks/task_execution_context.go index e7ed1d7e5f..e8795e7c1a 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/mocks/task_execution_context.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/mocks/task_execution_context.go @@ -18,11 +18,6 @@ type TaskExecutionContext struct { mock.Mock } -func (_m *TaskExecutionContext) ConnectionManager() core.ConnectionManager { - //TODO implement me - panic("implement me") -} - type TaskExecutionContext_Catalog struct { *mock.Call } @@ -57,6 +52,40 @@ func (_m *TaskExecutionContext) Catalog() catalog.AsyncClient { return r0 } +type TaskExecutionContext_ConnectionManager struct { + *mock.Call +} + +func (_m TaskExecutionContext_ConnectionManager) Return(_a0 core.ConnectionManager) *TaskExecutionContext_ConnectionManager { + return &TaskExecutionContext_ConnectionManager{Call: _m.Call.Return(_a0)} +} + +func (_m *TaskExecutionContext) OnConnectionManager() *TaskExecutionContext_ConnectionManager { + c_call := _m.On("ConnectionManager") + return &TaskExecutionContext_ConnectionManager{Call: c_call} +} + +func (_m *TaskExecutionContext) OnConnectionManagerMatch(matchers ...interface{}) *TaskExecutionContext_ConnectionManager { + c_call := _m.On("ConnectionManager", matchers...) + return &TaskExecutionContext_ConnectionManager{Call: c_call} +} + +// ConnectionManager provides a mock function with given fields: +func (_m *TaskExecutionContext) ConnectionManager() core.ConnectionManager { + ret := _m.Called() + + var r0 core.ConnectionManager + if rf, ok := ret.Get(0).(func() core.ConnectionManager); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(core.ConnectionManager) + } + } + + return r0 +} + type TaskExecutionContext_DataStore struct { *mock.Call } diff --git a/flyteplugins/go/tasks/pluginmachinery/core/secret_manager.go b/flyteplugins/go/tasks/pluginmachinery/core/secret_manager.go index a35782049b..aff5397922 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/secret_manager.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/secret_manager.go @@ -2,8 +2,6 @@ package core import "context" -//go:generate mockery -all -output=./mocks -case=underscore - type SecretManager interface { Get(ctx context.Context, key string) (string, error) } diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go index 6fb3828c0c..6ef3067057 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go @@ -104,6 +104,8 @@ func TestEndToEnd(t *testing.T) { }, }, } + template.SecurityContext = &flyteIdlCore.SecurityContext{Connection: "openai"} + expectedOutputs, err := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1}) assert.NoError(t, err) phase := tests.RunPluginEndToEndTest(t, plugin, &template, inputs, expectedOutputs, nil, iter) diff --git a/flyteplugins/tests/end_to_end.go b/flyteplugins/tests/end_to_end.go index 1a55b52c0c..e405431acf 100644 --- a/flyteplugins/tests/end_to_end.go +++ b/flyteplugins/tests/end_to_end.go @@ -231,6 +231,10 @@ func RunPluginEndToEndTest(t *testing.T, executor pluginCore.Plugin, template *i secretManager := &coreMocks.SecretManager{} secretManager.OnGet(ctx, mock.Anything).Return("fake-token", nil) + connection := idlCore.Connection{Secrets: map[string]string{"OPENAI_API_KEY": "123"}} + connectionManager := &coreMocks.ConnectionManager{} + connectionManager.OnGet(ctx, mock.Anything).Return(connection, nil) + tCtx := &coreMocks.TaskExecutionContext{} tCtx.OnInputReader().Return(inputReader) tCtx.OnTaskRefreshIndicator().Return(func(ctx context.Context) {}) @@ -245,6 +249,7 @@ func RunPluginEndToEndTest(t *testing.T, executor pluginCore.Plugin, template *i tCtx.OnResourceManager().Return(resourceManager) tCtx.OnMaxDatasetSizeBytes().Return(1000000) tCtx.OnSecretManager().Return(secretManager) + tCtx.OnConnectionManager().Return(connectionManager) trns := pluginCore.DoTransition(pluginCore.PhaseInfoQueued(time.Now(), 0, "")) for !trns.Info().Phase().IsTerminal() { diff --git a/flytepropeller/pkg/controller/nodes/task/connectionmanager/config.go b/flytepropeller/pkg/controller/nodes/task/connectionmanager/config.go index 65058cb736..9224322bb9 100644 --- a/flytepropeller/pkg/controller/nodes/task/connectionmanager/config.go +++ b/flytepropeller/pkg/controller/nodes/task/connectionmanager/config.go @@ -19,6 +19,10 @@ type Config struct { Connection map[string]Connection `json:"connection" pflag:", the connection that saves the secrets and configs"` } +func SetConfig(cfg *Config) error { + return section.SetConfig(cfg) +} + func GetConfig() *Config { return section.GetConfig().(*Config) } diff --git a/flytepropeller/pkg/controller/nodes/task/connectionmanager/connection.go b/flytepropeller/pkg/controller/nodes/task/connectionmanager/connection.go index b26069a109..e455e2d368 100644 --- a/flytepropeller/pkg/controller/nodes/task/connectionmanager/connection.go +++ b/flytepropeller/pkg/controller/nodes/task/connectionmanager/connection.go @@ -3,16 +3,19 @@ package connectionmanager import ( "context" "fmt" - "os" - flyteidl "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + pluginCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/task/secretmanager" + "github.com/flyteorg/flyte/flytestdlib/logger" ) // FileEnvConnectionManager allows retrieving secrets mounted to this process through Env Vars or Files. -type FileEnvConnectionManager struct{} +type FileEnvConnectionManager struct { + secretManager pluginCore.SecretManager +} // Get retrieves a secret from the environment of the running process or from a file. -func (f FileEnvConnectionManager) Get(_ context.Context, key string) (flyteidl.Connection, error) { +func (f FileEnvConnectionManager) Get(ctx context.Context, key string) (flyteidl.Connection, error) { cfg := GetConfig() connection, ok := cfg.Connection[key] if !ok { @@ -20,21 +23,19 @@ func (f FileEnvConnectionManager) Get(_ context.Context, key string) (flyteidl.C } secret := make(map[string]string) for k, v := range connection.Secrets { - // TODO: Read the secret from a local file - val, ok := os.LookupEnv(v) - if !ok { - return flyteidl.Connection{}, fmt.Errorf("secret not found in env [%s]", v) + val, err := f.secretManager.Get(ctx, v) + if err != nil { + logger.Errorf(ctx, "failed to get secret [%s] for connection [%s] with error: %v", v, k, err) + return flyteidl.Connection{}, err } secret[k] = val } - config := make(map[string]string) - for k, v := range connection.Configs { - config[k] = v - } - return flyteidl.Connection{Secrets: secret, Configs: config}, nil + return flyteidl.Connection{Secrets: secret, Configs: connection.Configs}, nil } func NewFileEnvConnectionManager() FileEnvConnectionManager { - return FileEnvConnectionManager{} + return FileEnvConnectionManager{ + secretManager: secretmanager.NewFileEnvSecretManager(secretmanager.GetConfig()), + } } diff --git a/flytepropeller/pkg/controller/nodes/task/connectionmanager/connection_test.go b/flytepropeller/pkg/controller/nodes/task/connectionmanager/connection_test.go new file mode 100644 index 0000000000..68f943845c --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/task/connectionmanager/connection_test.go @@ -0,0 +1,38 @@ +package connectionmanager + +import ( + "context" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "testing" + + coreMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" +) + +func TestConnectionManager(t *testing.T) { + ctx := context.Background() + fakeSecretManager := &coreMocks.SecretManager{} + fakeSecretManager.OnGet(ctx, mock.Anything).Return("fake-token", nil) + + connectionManager := NewFileEnvConnectionManager() + connectionManager.secretManager = fakeSecretManager + + cfg := defaultConfig + cfg.Connection = map[string]Connection{ + "openai": { + Secrets: map[string]string{ + "openai_api_key": "api_key", + }, + Configs: map[string]string{ + "openai_organization": "flyteorg", + }, + }, + } + err := SetConfig(cfg) + assert.Nil(t, err) + + connection, err := connectionManager.Get(ctx, "openai") + assert.Nil(t, err) + assert.Equal(t, "fake-token", connection.Secrets["openai_api_key"]) + assert.Equal(t, "flyteorg", connection.Configs["openai_organization"]) +}