diff --git a/roles/slurm/files/test_update_config.py b/roles/slurm/files/test_update_config.py new file mode 100644 index 00000000..99f437b2 --- /dev/null +++ b/roles/slurm/files/test_update_config.py @@ -0,0 +1,47 @@ +import pytest +import requests_mock +import yaml + +import update_config + + +@pytest.mark.parametrize( + "data", + [ + (""" + VM.Standard2.1: + 1: 1 + 2: 1 + 3: 1 + VM.Standard2.2: + 1: 1 + 2: 1 + 3: 1 + """), + ] +) +def test_get_limits(mocker, data): + mocker.patch("update_config.load_yaml", return_value=yaml.safe_load(data)) + assert isinstance(update_config.get_limits(), dict) + + +@pytest.mark.parametrize( + "data,error", + [ + (""" + VM.Standard2.1: + 1: 1 + 2: 1 + 3: 1 + VM.Standard2.2: + 1: 1 + 2: 1 + 3: 1 + """, + SyntaxError), + ] +) +def test_get_limits_errors(mocker, data, error): + mocker.patch("update_config.load_yaml", return_value=yaml.safe_load(data)) + with pytest.raises(error): + update_config.get_limits() diff --git a/roles/slurm/files/update_config.py b/roles/slurm/files/update_config.py index 172a0409..e58107f1 100644 --- a/roles/slurm/files/update_config.py +++ b/roles/slurm/files/update_config.py @@ -16,7 +16,19 @@ def get_limits() -> Dict[str, Dict[str, str]]: Until OCI has an API to fetch service limits, we have to hard-code them in a file. """ - return load_yaml("limits.yaml") + limits = load_yaml("limits.yaml") + for mappings in limits.values(): + if not isinstance(mappings, dict): + raise SyntaxError + for ad, count in mappings.items(): + if not isinstance(ad, int): + raise SyntaxError + if not isinstance(count, int): + raise SyntaxError + for shape in limits: + if not re.match(r"", shape): + raise ValueError + return limits def get_shapes() -> Dict[str, Dict[str, str]]: @@ -78,7 +90,12 @@ def get_node_configs(limits, shapes, mgmt_info): slurm_conf_filename = "/mnt/shared/etc/slurm/slurm.conf" - node_config = "\n".join(get_node_configs(get_limits(), get_shapes(), get_mgmt_info())) + try: + limits = get_limits() + except SyntaxError: + print("ERROR: Syntax error in `limits.yaml`.") + exit(1) + node_config = "\n".join(get_node_configs(limits, get_shapes(), get_mgmt_info())) chop = re.compile('(?<=# STARTNODES\n)(.*?)(?=\n?# ENDNODES)', re.DOTALL)