Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jedel1043 committed Sep 9, 2024
1 parent f1234df commit 3e1e393
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 72 deletions.
73 changes: 6 additions & 67 deletions lib/charms/hpc_libs/v0/slurm_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,27 +129,6 @@ def _snap(*args) -> str:
return _call("snap", *args)


def format_key(key: str) -> str:
"""Format Slurm configuration keys from SlurmCASe into kebab case.
Args:
key: Slurm configuration key to convert to kebab case.
Notes:
Slurm configuration syntax does not follow proper PascalCasing
format, so we cannot put keys directly through a kebab case converter
to get the desired format. Some additional processing is needed for
certain keys before the key can properly kebabized.
For example, without additional preprocessing, the key `CPUs` will
become `cp-us` if put through a kebabizer with being preformatted to `Cpus`.
"""
if "CPUs" in key:
key = key.replace("CPUs", "Cpus")
key = _acronym.sub(r"-", key)
return _kebabize.sub(r"-", key).lower()


class SlurmOpsError(Exception):
"""Exception raised when a slurm operation failed."""

Expand Down Expand Up @@ -233,42 +212,30 @@ def generate(self) -> None:


class _EnvManager:
"""Control configuration of environment variables used in Slurm components."""
"""Control configuration of environment variables used in Slurm components.
Every configuration value is automatically uppercased and prefixed with the service name.
"""

def __init__(self, file: Union[str, os.PathLike], prefix: str, keys: [str]) -> None:
def __init__(self, file: Union[str, os.PathLike], prefix: str) -> None:
self._file: Path = Path(file)
self._service = prefix
self._keys: [str] = keys

def _config_to_env_var(self, key: str) -> str:
"""Get the environment variable corresponding to the configuration `key`."""
return self._service.replace("-", "_").upper() + "_" + key.replace("-", "_").upper()

def get(self, key: str) -> Optional[str]:
"""Get specific environment variable for service."""
if key not in self._keys:
raise SlurmOpsError(f"invalid configuration key `{key}` for service `{self._service}`")

return dotenv.get_key(self._file, self._config_to_env_var(key))

def set(self, config: Mapping[str, Any]) -> None:
"""Set environment variable for service."""
# Check to avoid modifying the .env if the input keys are invalid.
for key in config.keys():
if key not in self._keys:
raise SlurmOpsError(
f"invalid configuration key `{key}` for service `{self._service}`"
)

for key, value in config.items():
dotenv.set_key(self._file, self._config_to_env_var(key), str(value))

def unset(self, key: str) -> None:
"""Unset environment variable for service."""
# Check to avoid modifying the .env if the input keys are invalid.
if key not in self._keys:
raise SlurmOpsError(f"invalid configuration key `{key}` for service `{self._service}`")

dotenv.unset_key(self._file, self._config_to_env_var(key))


Expand Down Expand Up @@ -307,25 +274,6 @@ class MungeManager:
def __init__(self, ops_manager: SlurmOpsManager) -> None:
self.service = ops_manager.service_manager_for(ServiceType.MUNGED)
self.key = ops_manager.munge_key_manager()
self._env_manager = ops_manager._env_manager_for(ServiceType.MUNGED)

@property
def max_thread_count(self) -> Optional[int]:
"""Get the max number of threads that munged can use."""
if not (mtc := self._env_manager.get("max-thread-count")):
return None

return int(mtc)

@max_thread_count.setter
def max_thread_count(self, count: int) -> None:
"""Set the max number of threads that munged can use."""
self._env_manager.set({"max-thread-count": count})

@max_thread_count.deleter
def max_thread_count(self, count: int) -> None:
"""Unset the max number of threads that munged can use."""
self._env_manager.unset("max-thread-count")


class PrometheusExporterManager:
Expand Down Expand Up @@ -486,15 +434,6 @@ def generate(self) -> None:
class SnapManager(SlurmOpsManager):
"""Slurm ops manager that uses Snap as its package manager."""

_ENV_KEYS = {
ServiceType.SLURMCTLD: [],
ServiceType.SLURMDBD: [],
ServiceType.MUNGED: ["max-thread-count"],
ServiceType.SLURMD: ["config-server"],
ServiceType.PROMETHEUS_EXPORTER: [],
ServiceType.SLURMRESTD: ["max-connections", "max-thread-count"],
}

def install(self) -> None:
"""Install Slurm using the `slurm` snap."""
# FIXME: Pin slurm to the stable channel
Expand All @@ -519,7 +458,7 @@ def service_manager_for(self, type: ServiceType) -> ServiceManager:
def _env_manager_for(self, type: ServiceType) -> _EnvManager:
"""Return the `_EnvManager` for the specified `ServiceType`."""
return _EnvManager(
file="/var/snap/slurm/common/.env", prefix=type.value, keys=self._ENV_KEYS[type]
file="/var/snap/slurm/common/.env", prefix=type.value
)

def munge_key_manager(self) -> MungeKeyManager:
Expand Down
5 changes: 0 additions & 5 deletions tests/unit/test_slurm_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,6 @@

@patch("charms.hpc_libs.v0.slurm_ops.subprocess.check_output")
class TestSlurmOps(TestCase):
def test_format_key(self, _) -> None:
"""Test that `kebabize` properly formats slurm keys."""
self.assertEqual(slurm.format_key("CPUs"), "cpus")
self.assertEqual(slurm.format_key("AccountingStorageHost"), "accounting-storage-host")

def test_error_message(self, *_) -> None:
"""Test that `SlurmOpsError` stores the correct message."""
message = "error message!"
Expand Down

0 comments on commit 3e1e393

Please sign in to comment.