Skip to content

Commit

Permalink
Do not check for secret groups during runtime (flyteorg#2355)
Browse files Browse the repository at this point in the history
* Do not check for secret groups during runtime

Signed-off-by: Thomas J. Fan <[email protected]>

* TST Adds test for no groups

Signed-off-by: Thomas J. Fan <[email protected]>

---------

Signed-off-by: Thomas J. Fan <[email protected]>
  • Loading branch information
thomasjpfan authored Apr 19, 2024
1 parent bc153b3 commit b0cd1e8
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 17 deletions.
2 changes: 1 addition & 1 deletion flytekit/configuration/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def configure_pyflyte_cli(main: Group) -> Group:

@staticmethod
def secret_requires_group() -> bool:
"""Return True if secrets require group entry."""
"""Return True if secrets require group entry during registration time."""
return True

@staticmethod
Expand Down
10 changes: 0 additions & 10 deletions flytekit/core/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,6 @@ def get(
Retrieves a secret using the resolution order -> Env followed by file. If not found raises a ValueError
param encode_mode, defines the mode to open files, it can either be "r" to read file, or "rb" to read binary file
"""
self.check_group_key(group)
env_var = self.get_secrets_env_var(group, key, group_version)
fpath = self.get_secrets_file(group, key, group_version)
v = os.environ.get(env_var)
Expand All @@ -380,7 +379,6 @@ def get_secrets_env_var(
"""
Returns a string that matches the ENV Variable to look for the secrets
"""
self.check_group_key(group)
l = [k.upper() for k in filter(None, (group, group_version, key))]
return f"{self._env_prefix}{'_'.join(l)}"

Expand All @@ -390,18 +388,10 @@ def get_secrets_file(
"""
Returns a path that matches the file to look for the secrets
"""
self.check_group_key(group)
l = [k.lower() for k in filter(None, (group, group_version, key))]
l[-1] = f"{self._file_prefix}{l[-1]}"
return os.path.join(self._base_dir, *l)

@staticmethod
def check_group_key(group: Optional[str]):
from flytekit.configuration.plugin import get_plugin

if get_plugin().secret_requires_group() and (group is None or group == ""):
raise ValueError("secrets group is a mandatory field.")


@dataclass(frozen=True)
class CompilationState(object):
Expand Down
6 changes: 0 additions & 6 deletions tests/flytekit/unit/core/test_context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,6 @@ def test_secrets_manager_default():

def test_secrets_manager_get_envvar():
sec = SecretsManager()
with pytest.raises(ValueError):
sec.get_secrets_env_var("", "x")
cfg = SecretsConfig.auto()
assert sec.get_secrets_env_var("group", "test") == f"{cfg.env_prefix}GROUP_TEST"
assert sec.get_secrets_env_var("group", "test", "v1") == f"{cfg.env_prefix}GROUP_V1_TEST"
Expand All @@ -168,8 +166,6 @@ def test_secret_manager_no_group(monkeypatch):

sec = SecretsManager()
cfg = SecretsConfig.auto()
sec.check_group_key(None)
sec.check_group_key("")

assert sec.get_secrets_env_var(key="ABC") == f"{cfg.env_prefix}ABC"

Expand All @@ -180,8 +176,6 @@ def test_secret_manager_no_group(monkeypatch):

def test_secrets_manager_get_file():
sec = SecretsManager()
with pytest.raises(ValueError):
sec.get_secrets_file("", "x")
cfg = SecretsConfig.auto()
assert sec.get_secrets_file("group", "test") == os.path.join(
cfg.default_dir,
Expand Down
13 changes: 13 additions & 0 deletions tests/flytekit/unit/models/core/test_security.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from unittest.mock import Mock

import pytest

import flytekit.configuration.plugin
from flytekit.models.security import Secret

Expand All @@ -16,6 +18,17 @@ def test_secret():
assert obj2.group_version == "v1"


def test_secret_error(monkeypatch):
# Mock configuration to require groups for this test
plugin_mock = Mock()
plugin_mock.secret_requires_group.return_value = True
mock_global_plugin = {"plugin": plugin_mock}
monkeypatch.setattr(flytekit.configuration.plugin, "_GLOBAL_CONFIG", mock_global_plugin)

with pytest.raises(ValueError, match="Group is a required parameter"):
Secret(key="my_key")


def test_secret_no_group(monkeypatch):
plugin_mock = Mock()
plugin_mock.secret_requires_group.return_value = False
Expand Down

0 comments on commit b0cd1e8

Please sign in to comment.