From 48fb964ee4110acec7ea573e3e25f93493ef7ae8 Mon Sep 17 00:00:00 2001 From: Eric Berquist Date: Thu, 29 Aug 2024 09:05:32 -0600 Subject: [PATCH 1/7] add configuration file for Python-based tools --- pyproject.toml | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 pyproject.toml diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..4aa14ac12 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,19 @@ +[tool.mypy] +explicit_package_bases = true +mypy_path = "$MYPY_CONFIG_FILE_DIR/src/sst/core/testingframework" + +warn_unused_ignores = true + +warn_return_any = true +warn_unused_configs = true + +disallow_untyped_defs = true + +exclude = [ + '^scripts/', + '^tests/', +] + +[[tool.mypy.overrides]] +module = "sst" +ignore_missing_imports = true From d751513b36e92aa70ee4f7e75badd5090e38d862 Mon Sep 17 00:00:00 2001 From: Eric Berquist Date: Wed, 2 Oct 2024 13:25:50 -0600 Subject: [PATCH 2/7] fix regex escaping in xmlToPython.py --- src/sst/core/model/xmlToPython.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sst/core/model/xmlToPython.py b/src/sst/core/model/xmlToPython.py index a1d54e0a3..ad74d3667 100755 --- a/src/sst/core/model/xmlToPython.py +++ b/src/sst/core/model/xmlToPython.py @@ -35,8 +35,8 @@ def printTree(indent, node): # Some regular expressions sdlRE = re.compile("") commentRE = re.compile("", re.DOTALL) -eqRE = re.compile("(<[^>]+?\w+)=([^\"\'][^\\s/>]*)") # This one is suspect -namespaceRE = re.compile("<\s*((\w+):\w+)") +eqRE = re.compile(r"(<[^>]+?\w+)=([^\"\'][^\\s/>]*)") # This one is suspect +namespaceRE = re.compile(r"<\s*((\w+):\w+)") envVarRE = re.compile("\\${(.*?)}", re.DOTALL) sstVarRE = re.compile("\\$([^{][a-zA-Z0-9_]+)", re.DOTALL) From 616fd6792cd1bfd4a2d69c09d4dd8dad50d76ff3 Mon Sep 17 00:00:00 2001 From: Eric Berquist Date: Wed, 2 Oct 2024 15:24:29 -0600 Subject: [PATCH 3/7] add type annotations and Python idioms in xmlToPython.py --- src/sst/core/model/xmlToPython.py | 54 ++++++++++++++++--------------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/src/sst/core/model/xmlToPython.py b/src/sst/core/model/xmlToPython.py index ad74d3667..5c41dfc18 100755 --- a/src/sst/core/model/xmlToPython.py +++ b/src/sst/core/model/xmlToPython.py @@ -42,19 +42,21 @@ def printTree(indent, node): sstVarRE = re.compile("\\$([^{][a-zA-Z0-9_]+)", re.DOTALL) -def processString(str): +def processString(string: str) -> str: """Process a string, replacing variables and env. vars with their values""" - def replaceSSTVar(matchobj): + def replaceSSTVar(matchobj: re.Match) -> str: varname = matchobj.group(1) return sstVars[varname] - def replaceEnvVar(matchobj): + def replaceEnvVar(matchobj: re.Match) -> str: varname = matchobj.group(1) - return os.getenv(varname) + var = os.getenv(varname) + assert var is not None + return var - str = envVarRE.sub(replaceEnvVar, str) - str = sstVarRE.sub(replaceSSTVar, str) - return str + string = envVarRE.sub(replaceEnvVar, string) + string = sstVarRE.sub(replaceSSTVar, string) + return string def getLink(name): @@ -72,26 +74,26 @@ def getParamName(node): return name -def processParamSets(set): - for group in set: +def processParamSets(groups: ET.Element) -> None: + for group in groups: params = dict() for p in group: - params[getParamName(p)] = processString(p.text.strip()) + params[getParamName(p)] = processString(p.text.strip()) # type: ignore sstParams[group.tag] = params -def processVars(varNode): +def processVars(varNode: ET.Element) -> None: for var in varNode: - sstVars[var.tag] = processString(var.text.strip()) + sstVars[var.tag] = processString(var.text.strip()) # type: ignore -def processConfig(cfg): - for line in cfg.text.strip().splitlines(): +def processConfig(cfg: ET.Element) -> None: + for line in cfg.text.strip().splitlines(): # type: ignore var, val = line.split('=') sst.setProgramOption(var, processString(val)) # strip quotes -def buildComp(compNode): +def buildComp(compNode: ET.Element) -> None: name = processString(compNode.attrib['name']) type = processString(compNode.attrib['type']) comp = sst.Component(name, type) @@ -99,12 +101,12 @@ def buildComp(compNode): # Process Parameters paramsNode = compNode.find("params") params = dict() - if paramsNode != None: + if paramsNode is not None: if "include" in paramsNode.attrib: for paramInc in paramsNode.attrib['include'].split(','): params.update(sstParams[processString(paramInc)]) for p in paramsNode: - params[getParamName(p)] = processString(p.text.strip()) + params[getParamName(p)] = processString(p.text.strip()) # type: ignore comp.addParams(params) @@ -131,23 +133,23 @@ def buildGraph(graph): -def build(root): +def build(root: ET.Element) -> None: paramSets = root.find("param_include") - vars = root.find("variables") + variables = root.find("variables") cfg = root.find("config") timebase = root.find("timebase") graph = root.find("sst") - if None != vars: - processVars(vars) - if None != paramSets: + if variables is not None: + processVars(variables) + if paramSets is not None: processParamSets(paramSets) - if None != timebase: - sst.setProgramOption('timebase', timebase.text.strip()) - if None != cfg: + if timebase is not None: + sst.setProgramOption('timebase', timebase.text.strip()) # type: ignore + if cfg is not None: processConfig(cfg) - if None != graph: + if graph is not None: buildGraph(graph) From 3d5350146997b690535eecb3f1a6249975b64f9f Mon Sep 17 00:00:00 2001 From: Eric Berquist Date: Wed, 2 Oct 2024 15:27:00 -0600 Subject: [PATCH 4/7] sst_unittest_parameterized.py: remove old Python version import logic --- .../sst_unittest_parameterized.py | 21 ++++--------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/src/sst/core/testingframework/sst_unittest_parameterized.py b/src/sst/core/testingframework/sst_unittest_parameterized.py index 6f9c886c5..6d34b8353 100644 --- a/src/sst/core/testingframework/sst_unittest_parameterized.py +++ b/src/sst/core/testingframework/sst_unittest_parameterized.py @@ -39,20 +39,8 @@ import warnings from functools import wraps from types import MethodType as MethodType -from collections import namedtuple - -try: - from collections import OrderedDict as MaybeOrderedDict -except ImportError: - MaybeOrderedDict = dict - -from unittest import TestCase - -try: - from unittest import SkipTest -except ImportError: - class SkipTest(Exception): - pass +from collections import namedtuple, OrderedDict +from unittest import SkipTest, TestCase lzip = lambda *a: list(zip(*a)) @@ -166,9 +154,8 @@ def __repr__(self): return "param(*%r, **%r)" %self -class QuietOrderedDict(MaybeOrderedDict): - """ When OrderedDict is available, use it to make sure that the kwargs in - doc strings are consistently ordered. """ +class QuietOrderedDict(OrderedDict): + """ Have an OrderedDict visually represented as a dict. """ __str__ = dict.__str__ __repr__ = dict.__repr__ From cb564ce97fc455898d957e07b5231b476f22003e Mon Sep 17 00:00:00 2001 From: Eric Berquist Date: Wed, 2 Oct 2024 15:31:54 -0600 Subject: [PATCH 5/7] fix SSTTextTestRunner.did_tests_pass --- src/sst/core/testingframework/test_engine_unittest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sst/core/testingframework/test_engine_unittest.py b/src/sst/core/testingframework/test_engine_unittest.py index cb2e809c1..32bc2a6c9 100644 --- a/src/sst/core/testingframework/test_engine_unittest.py +++ b/src/sst/core/testingframework/test_engine_unittest.py @@ -155,7 +155,7 @@ def did_tests_pass(self, run_results): Returns: True if all tests passing with no errors, false otherwise """ - return run_results.wasSuccessful and \ + return run_results.wasSuccessful() and \ len(run_results.failures) == 0 and \ len(run_results.errors) == 0 and \ len(run_results.unexpectedSuccesses) == 0 and \ From 278b2cf2b056ea5ad2095b1a19e874447e95c843 Mon Sep 17 00:00:00 2001 From: Eric Berquist Date: Thu, 29 Aug 2024 09:08:00 -0600 Subject: [PATCH 6/7] add type annotations to Python code --- scripts/format-diff | 5 +- src/sst/core/model/xmlToPython.py | 11 +- .../sst_test_engine_loader.py | 4 +- src/sst/core/testingframework/sst_unittest.py | 33 +-- .../testingframework/sst_unittest_support.py | 235 +++++++++++------- src/sst/core/testingframework/test_engine.py | 27 +- .../testingframework/test_engine_globals.py | 7 +- .../testingframework/test_engine_junit.py | 7 +- .../testingframework/test_engine_support.py | 3 +- .../testingframework/test_engine_unittest.py | 75 +++--- 10 files changed, 230 insertions(+), 177 deletions(-) diff --git a/scripts/format-diff b/scripts/format-diff index e612163a5..af8ec0ec9 100755 --- a/scripts/format-diff +++ b/scripts/format-diff @@ -3,6 +3,7 @@ import sys import re from subprocess import check_output,STDOUT +from typing import List choke_points = [ "ser &", @@ -11,11 +12,11 @@ choke_points = [ commit = sys.argv[1] paths = sys.argv[2:] -def getoutput(cmd_arr): +def getoutput(cmd_arr: List[str]) -> str: result = check_output(cmd_arr,stderr=STDOUT,stdin=None).decode("utf-8").rstrip("\n") return result -def format_diff(commit, path): +def format_diff(commit: str, path: str) -> None: cmd = ["git", "diff", commit, "HEAD", path ] diff_text = getoutput(cmd) diff --git a/src/sst/core/model/xmlToPython.py b/src/sst/core/model/xmlToPython.py index 5c41dfc18..a0b120212 100755 --- a/src/sst/core/model/xmlToPython.py +++ b/src/sst/core/model/xmlToPython.py @@ -13,10 +13,11 @@ import xml.etree.ElementTree as ET import sys, os, re +from typing import Dict import sst -def printTree(indent, node): +def printTree(indent: int, node: ET.Element) -> None: print("%sBegin %s: %r"%(' '*indent, node.tag, node.attrib)) if node.text and len(node.text.strip()): print("%sText: %s"%(' '*indent, node.text.strip())) @@ -27,7 +28,7 @@ def printTree(indent, node): # Various global lookups -sstVars = dict() +sstVars: Dict[str, str] = dict() sstParams = dict() sstLinks = dict() @@ -59,14 +60,14 @@ def replaceEnvVar(matchobj: re.Match) -> str: return string -def getLink(name): +def getLink(name: str) -> sst.Link: if name not in sstLinks: sstLinks[name] = sst.Link(name) return sstLinks[name] -def getParamName(node): +def getParamName(node: ET.Element) -> str: name = node.tag.strip() if name[0] == "{": ns, tag = name[1:].split("}") @@ -127,7 +128,7 @@ def buildComp(compNode: ET.Element) -> None: -def buildGraph(graph): +def buildGraph(graph: ET.Element) -> None: for comp in graph.findall("component"): buildComp(comp) diff --git a/src/sst/core/testingframework/sst_test_engine_loader.py b/src/sst/core/testingframework/sst_test_engine_loader.py index 8e62816ce..829748ca8 100644 --- a/src/sst/core/testingframework/sst_test_engine_loader.py +++ b/src/sst/core/testingframework/sst_test_engine_loader.py @@ -35,7 +35,7 @@ ################################################################################ -def startup_and_run(sst_core_bin_dir, test_mode): +def startup_and_run(sst_core_bin_dir: str, test_mode: int) -> None: """ This is the main entry point for loading and running the SST Test Frameworks Engine. @@ -138,7 +138,7 @@ def _generic_exception_handler(exc_e): #### -def _verify_test_frameworks_is_available(sst_core_frameworks_dir): +def _verify_test_frameworks_is_available(sst_core_frameworks_dir: str) -> None: """ Ensure that all test framework files are available. :param: sst_core_frameworks_dir = Dir of the test frameworks """ diff --git a/src/sst/core/testingframework/sst_unittest.py b/src/sst/core/testingframework/sst_unittest.py index df8bdc21e..b3a405a64 100644 --- a/src/sst/core/testingframework/sst_unittest.py +++ b/src/sst/core/testingframework/sst_unittest.py @@ -28,6 +28,7 @@ import threading import signal import time +from typing import Optional import test_engine_globals from sst_unittest_support import * @@ -40,7 +41,7 @@ #from test_engine_junit import junit_to_xml_report_string if not sys.warnoptions: - import os, warnings + import warnings warnings.simplefilter("once") # Change the filter in this process os.environ["PYTHONWARNINGS"] = "once" # Also affect subprocesses @@ -54,12 +55,12 @@ class SSTTestCase(unittest.TestCase): basic resource for how to develop tests for this frameworks. """ - def __init__(self, methodName): + def __init__(self, methodName: str) -> None: # NOTE: __init__ is called at startup for all tests before any # setUpModules(), setUpClass(), setUp() and the like are called. super(SSTTestCase, self).__init__(methodName) self.testname = methodName - parent_module_path = os.path.dirname(sys.modules[self.__class__.__module__].__file__) + parent_module_path: str = os.path.dirname(sys.modules[self.__class__.__module__].__file__) # type: ignore self._testsuite_dirpath = parent_module_path #log_forced("SSTTestCase: __init__() - {0}".format(self.testname)) self.initializeClass(self.testname) @@ -68,7 +69,7 @@ def __init__(self, methodName): ### - def initializeClass(self, testname): + def initializeClass(self, testname: str) -> None: """ The method is called by the Frameworks immediately before class is initialized. @@ -92,7 +93,7 @@ def initializeClass(self, testname): ### - def setUp(self): + def setUp(self) -> None: """ The method is called by the Frameworks immediately before a test is run **NOTICE**: @@ -115,7 +116,7 @@ def setUp(self): ### - def tearDown(self): + def tearDown(self) -> None: """ The method is called by the Frameworks immediately after a test finishes **NOTICE**: @@ -136,7 +137,7 @@ def tearDown(self): ### @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: """ This method is called by the Frameworks immediately before the TestCase starts **NOTICE**: @@ -154,7 +155,7 @@ def setUpClass(cls): ### @classmethod - def tearDownClass(cls): + def tearDownClass(cls) -> None: """ This method is called by the Frameworks immediately after a TestCase finishes **NOTICE**: @@ -171,7 +172,7 @@ def tearDownClass(cls): ### - def get_testsuite_name(self): + def get_testsuite_name(self) -> str: """ Return the testsuite (module) name Returns: @@ -181,7 +182,7 @@ def get_testsuite_name(self): ### - def get_testcase_name(self): + def get_testcase_name(self) -> str: """ Return the testcase name Returns: @@ -190,7 +191,7 @@ def get_testcase_name(self): return "{0}".format(strqual(self.__class__)) ### - def get_testsuite_dir(self): + def get_testsuite_dir(self) -> str: """ Return the directory path of the testsuite that is being run Returns: @@ -200,7 +201,7 @@ def get_testsuite_dir(self): ### - def get_test_output_run_dir(self): + def get_test_output_run_dir(self) -> str: """ Return the path of the test output run directory Returns: @@ -210,7 +211,7 @@ def get_test_output_run_dir(self): ### - def get_test_output_tmp_dir(self): + def get_test_output_tmp_dir(self) -> str: """ Return the path of the test tmp directory Returns: @@ -220,7 +221,7 @@ def get_test_output_tmp_dir(self): ### - def get_test_runtime_sec(self): + def get_test_runtime_sec(self) -> float: """ Return the current runtime (walltime) of the test Returns: @@ -377,7 +378,7 @@ def run_sst(self, sdl_file, out_file, err_file=None, set_cwd=None, mpi_out_files ### Module level support ################################################################################ -def setUpModule(): +def setUpModule() -> None: """ Perform setup functions before the testing Module loads. This function is called by the Frameworks before tests in any TestCase @@ -400,7 +401,7 @@ def setUpModule(): ### -def tearDownModule(): +def tearDownModule() -> None: """ Perform teardown functions immediately after a testing Module finishes. This function is called by the Frameworks after all tests in all TestCases diff --git a/src/sst/core/testingframework/sst_unittest_support.py b/src/sst/core/testingframework/sst_unittest_support.py index 91f61f17d..e4cd6ea7f 100644 --- a/src/sst/core/testingframework/sst_unittest_support.py +++ b/src/sst/core/testingframework/sst_unittest_support.py @@ -26,7 +26,7 @@ import shutil import difflib import configparser -from typing import List, Sequence, Type +from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Tuple, Union import test_engine_globals from test_engine_support import OSCommand @@ -57,10 +57,10 @@ class SSTTestCaseException(Exception): """ Generic Exception support for SSTTestCase """ - def __init__(self, value): + def __init__(self, value: Exception) -> None: super(SSTTestCaseException, self).__init__(value) self.value = value - def __str__(self): + def __str__(self) -> str: return repr(self.value) ################################################################################ @@ -69,7 +69,7 @@ def __str__(self): # Commandline Information Functions ################################################################################ -def testing_check_is_in_debug_mode(): +def testing_check_is_in_debug_mode() -> bool: """ Identify if test frameworks is in debug mode Returns: @@ -79,7 +79,7 @@ def testing_check_is_in_debug_mode(): ### -def testing_check_is_in_log_failures_mode(): +def testing_check_is_in_log_failures_mode() -> bool: """ Identify if test frameworks is in log failures mode Returns: @@ -89,7 +89,7 @@ def testing_check_is_in_log_failures_mode(): ### -def testing_check_is_in_concurrent_mode(): +def testing_check_is_in_concurrent_mode() -> bool: """ Identify if test frameworks is in concurrent mode Returns: @@ -99,7 +99,7 @@ def testing_check_is_in_concurrent_mode(): ### -def testing_check_get_num_ranks(): +def testing_check_get_num_ranks() -> int: """ Get the number of ranks defined to be run during testing Returns: @@ -109,7 +109,7 @@ def testing_check_get_num_ranks(): ### -def testing_check_get_num_threads(): +def testing_check_get_num_threads() -> int: """ Get the number of threads defined to be run during testing Returns: @@ -121,7 +121,7 @@ def testing_check_get_num_threads(): # PIN Information Functions ################################################################################ -def testing_is_PIN_loaded(): +def testing_is_PIN_loaded() -> bool: # Look to see if PIN is available pindir_found = False pin_path = os.environ.get('INTEL_PIN_DIRECTORY') @@ -130,7 +130,7 @@ def testing_is_PIN_loaded(): #log_debug("testing_is_PIN_loaded() - Intel_PIN_Path = {0}; Valid Dir = {1}".format(pin_path, pindir_found)) return pindir_found -def testing_is_PIN_Compiled(): +def testing_is_PIN_Compiled() -> bool: global pin_exec_path pin_crt = sst_elements_config_include_file_get_value_int("HAVE_PINCRT", 0, True) pin_exec = sst_elements_config_include_file_get_value_str("PINTOOL_EXECUTABLE", "", True) @@ -141,7 +141,7 @@ def testing_is_PIN_Compiled(): #log_debug("testing_is_PIN_Compiled() - Rtn {0}".format(rtn)) return rtn -def testing_is_PIN2_used(): +def testing_is_PIN2_used() -> bool: from warnings import warn warn("testing_is_PIN2_used() is deprecated and will be removed in future versions of SST.", DeprecationWarning, stacklevel=2) @@ -158,7 +158,7 @@ def testing_is_PIN2_used(): #log_debug("testing_is_PIN2_used() - Rtn False because PIN Not Compiled") return False -def testing_is_PIN3_used(): +def testing_is_PIN3_used() -> bool: global pin_exec_path if testing_is_PIN_Compiled(): if testing_is_PIN2_used(): @@ -183,7 +183,7 @@ def testing_is_PIN3_used(): # System Information Functions ################################################################################ -def host_os_get_system_node_name(): +def host_os_get_system_node_name() -> str: """ Get the node name of the system Returns: @@ -195,7 +195,7 @@ def host_os_get_system_node_name(): ### -def host_os_get_kernel_type(): +def host_os_get_kernel_type() -> str: """ Get the Kernel Type Returns: @@ -205,7 +205,7 @@ def host_os_get_kernel_type(): DeprecationWarning, stacklevel=2) return platform.system() -def host_os_get_kernel_release(): +def host_os_get_kernel_release() -> str: """ Get the Kernel Release number Returns: @@ -215,7 +215,7 @@ def host_os_get_kernel_release(): DeprecationWarning, stacklevel=2) return platform.release() -def host_os_get_kernel_arch(): +def host_os_get_kernel_arch() -> str: """ Get the Kernel System Arch Returns: @@ -225,7 +225,7 @@ def host_os_get_kernel_arch(): DeprecationWarning, stacklevel=2) return platform.machine() -def host_os_get_distribution_type(): +def host_os_get_distribution_type() -> str: """ Get the os distribution type Returns: @@ -256,7 +256,7 @@ def host_os_get_distribution_type(): return OS_DIST_OSX return OS_DIST_UNDEF -def host_os_get_distribution_version(): +def host_os_get_distribution_version() -> str: """ Get the os distribution version Returns: @@ -273,7 +273,7 @@ def host_os_get_distribution_version(): ### -def host_os_is_osx(): +def host_os_is_osx() -> bool: """ Check if OS distribution is OSX Returns: @@ -283,7 +283,7 @@ def host_os_is_osx(): DeprecationWarning, stacklevel=2) return host_os_get_distribution_type() == OS_DIST_OSX -def host_os_is_linux(): +def host_os_is_linux() -> bool: """ Check if OS distribution is Linux Returns: @@ -293,7 +293,7 @@ def host_os_is_linux(): DeprecationWarning, stacklevel=2) return not host_os_get_distribution_type() == OS_DIST_OSX -def host_os_is_centos(): +def host_os_is_centos() -> bool: """ Check if OS distribution is CentOS Returns: @@ -303,7 +303,7 @@ def host_os_is_centos(): DeprecationWarning, stacklevel=2) return host_os_get_distribution_type() == OS_DIST_CENTOS -def host_os_is_rhel(): +def host_os_is_rhel() -> bool: """ Check if OS distribution is RHEL Returns: @@ -313,7 +313,7 @@ def host_os_is_rhel(): DeprecationWarning, stacklevel=2) return host_os_get_distribution_type() == OS_DIST_RHEL -def host_os_is_toss(): +def host_os_is_toss() -> bool: """ Check if OS distribution is Toss Returns: @@ -323,7 +323,7 @@ def host_os_is_toss(): DeprecationWarning, stacklevel=2) return host_os_get_distribution_type() == OS_DIST_TOSS -def host_os_is_ubuntu(): +def host_os_is_ubuntu()-> bool: """ Check if OS distribution is Ubuntu Returns: @@ -333,7 +333,7 @@ def host_os_is_ubuntu(): DeprecationWarning, stacklevel=2) return host_os_get_distribution_type() == OS_DIST_UBUNTU -def host_os_is_rocky(): +def host_os_is_rocky() -> bool: """ Check if OS distribution is Rocky Returns: @@ -346,7 +346,7 @@ def host_os_is_rocky(): ### -def host_os_get_num_cores_on_system(): +def host_os_get_num_cores_on_system() -> int: """ Get number of cores on the system Returns: @@ -361,7 +361,7 @@ def host_os_get_num_cores_on_system(): # SST Skipping Support ################################################################################ -def _testing_check_is_scenario_filtering_enabled(scenario_name): +def _testing_check_is_scenario_filtering_enabled(scenario_name: str) -> bool: """ Determine if a scenario filter name is enabled Args: @@ -375,7 +375,7 @@ def _testing_check_is_scenario_filtering_enabled(scenario_name): ### -def skip_on_scenario(scenario_name, reason): +def skip_on_scenario(scenario_name: str, reason: str) -> Callable: """ Skip a test if a scenario filter name is enabled Args: @@ -390,7 +390,7 @@ def skip_on_scenario(scenario_name, reason): ### -def skip_on_sstsimulator_conf_empty_str(section, key, reason): +def skip_on_sstsimulator_conf_empty_str(section: str, key: str, reason: str) -> Callable: """ Skip a test if a section/key in the sstsimulator.conf file is missing an entry @@ -411,7 +411,7 @@ def skip_on_sstsimulator_conf_empty_str(section, key, reason): # SST Core Configuration include file (sst_config.h.conf) Access Functions ################################################################################ -def sst_core_config_include_file_get_value_int(define, default=None, disable_warning = False): +def sst_core_config_include_file_get_value_int(define: str, default: int = None, disable_warning: bool = False) -> int: """ Retrieve a define from the SST Core Configuration Include File (sst_config.h) Args: @@ -433,7 +433,11 @@ def sst_core_config_include_file_get_value_int(define, default=None, disable_war ### -def sst_core_config_include_file_get_value_str(define, default=None, disable_warning = False): +def sst_core_config_include_file_get_value_str( + define: str, + default: str = None, + disable_warning: bool = False, +) -> str: """ Retrieve a define from the SST Core Configuration Include File (sst_config.h) Args: @@ -455,7 +459,12 @@ def sst_core_config_include_file_get_value_str(define, default=None, disable_war ### -def sst_core_config_include_file_get_value(define: str, type: Type, default=None, disable_warning: bool=False): +def sst_core_config_include_file_get_value( + define: str, + type: Type, + default: Any = None, + disable_warning: bool = False, +) -> Any: """Retrieve a define from the SST Core Configuration Include File (sst_config.h) Args: @@ -474,7 +483,11 @@ def sst_core_config_include_file_get_value(define: str, type: Type, default=None # SST Elements Configuration include file (sst_element_config.h.conf) Access Functions ################################################################################ -def sst_elements_config_include_file_get_value_int(define, default=None, disable_warning = False): +def sst_elements_config_include_file_get_value_int( + define: str, + default: int = None, + disable_warning: bool = False, +) -> int: """ Retrieve a define from the SST Elements Configuration Include File (sst_element_config.h) Args: @@ -496,7 +509,11 @@ def sst_elements_config_include_file_get_value_int(define, default=None, disable ### -def sst_elements_config_include_file_get_value_str(define, default=None, disable_warning = False): +def sst_elements_config_include_file_get_value_str( + define: str, + default: str = None, + disable_warning: bool = False, +) -> str: """ Retrieve a define from the SST Elements Configuration Include File (sst_element_config.h) Args: @@ -518,7 +535,12 @@ def sst_elements_config_include_file_get_value_str(define, default=None, disable ### -def sst_elements_config_include_file_get_value(define: str, type: Type, default=None, disable_warning: bool=False): +def sst_elements_config_include_file_get_value( + define: str, + type: Type, + default: Any = None, + disable_warning: bool = False +) -> Any: """Retrieve a define from the SST Elements Configuration Include File (sst_element_config.h) Args: @@ -537,7 +559,7 @@ def sst_elements_config_include_file_get_value(define: str, type: Type, default= # SST Configuration file (sstsimulator.conf) Access Functions ################################################################################ -def sstsimulator_conf_get_value_str(section, key, default=None): +def sstsimulator_conf_get_value_str(section: str, key: str, default: str = None) -> str: """ Retrieve a Section/Key from the SST Configuration File (sstsimulator.conf) Args: @@ -557,7 +579,7 @@ def sstsimulator_conf_get_value_str(section, key, default=None): ### -def sstsimulator_conf_get_value_int(section, key, default=None): +def sstsimulator_conf_get_value_int(section: str, key: str, default: int = None) -> int: """ Retrieve a Section/Key from the SST Configuration File (sstsimulator.conf) Args: @@ -577,7 +599,7 @@ def sstsimulator_conf_get_value_int(section, key, default=None): ### -def sstsimulator_conf_get_value_float(section, key, default=None): +def sstsimulator_conf_get_value_float(section: str, key: str, default: float = None) -> float: """ Retrieve a Section/Key from the SST Configuration File (sstsimulator.conf) Args: @@ -597,7 +619,7 @@ def sstsimulator_conf_get_value_float(section, key, default=None): ### -def sstsimulator_conf_get_value_bool(section, key, default=None): +def sstsimulator_conf_get_value_bool(section: str, key: str, default: bool = None) -> bool: """ Retrieve a Section/Key from the SST Configuration File (sstsimulator.conf) NOTE: "1", "yes", "true", and "on" will return True; @@ -620,7 +642,7 @@ def sstsimulator_conf_get_value_bool(section, key, default=None): ### -def sstsimulator_conf_get_value(section: str, key: str, type: Type, default=None): +def sstsimulator_conf_get_value(section: str, key: str, type: Type, default: Any = None) -> Any: """Get the configuration value from the SST Configuration File (sstsimulator.conf) Args: @@ -636,7 +658,7 @@ def sstsimulator_conf_get_value(section: str, key: str, type: Type, default=None ### -def sstsimulator_conf_get_sections(): +def sstsimulator_conf_get_sections() -> List[str]: """ Retrieve a list of sections that exist in the SST Configuration File (sstsimulator.conf) Returns: @@ -653,7 +675,7 @@ def sstsimulator_conf_get_sections(): ### -def sstsimulator_conf_get_section_keys(section): +def sstsimulator_conf_get_section_keys(section: str) -> List[str]: """ Retrieve a list of keys under a section that exist in the SST Configuration File (sstsimulator.conf) @@ -675,7 +697,7 @@ def sstsimulator_conf_get_section_keys(section): ### -def sstsimulator_conf_get_all_keys_values_from_section(section): +def sstsimulator_conf_get_all_keys_values_from_section(section: str) -> List[Tuple[str, str]]: """ Retrieve a list of tuples that contain all the key - value pairs under a section that exists in the SST Configuration File (sstsimulator.conf) @@ -697,7 +719,7 @@ def sstsimulator_conf_get_all_keys_values_from_section(section): ### -def sstsimulator_conf_does_have_section(section): +def sstsimulator_conf_does_have_section(section: str) -> bool: """ Check if the SST Configuration File (sstsimulator.conf) has a defined section @@ -719,7 +741,7 @@ def sstsimulator_conf_does_have_section(section): ### -def sstsimulator_conf_does_have_key(section, key): +def sstsimulator_conf_does_have_key(section: str, key: str) -> bool: """ Check if the SST Configuration File (sstsimulator.conf) has a defined key within a section Args: @@ -744,7 +766,7 @@ def sstsimulator_conf_does_have_key(section, key): # Logging Functions ################################################################################ -def log(logstr): +def log(logstr: str) -> None: """ Log a general message. This will not output unless we are outputing in >= normal mode. @@ -758,7 +780,7 @@ def log(logstr): ### -def log_forced(logstr): +def log_forced(logstr: str) -> None: """ Log a general message, no matter what the verbosity is. if in the middle of testing, it will precede with a '\\n' to slip @@ -775,7 +797,7 @@ def log_forced(logstr): ### -def log_debug(logstr): +def log_debug(logstr: str) -> None: """ Log a 'DEBUG:' message. Log will only happen if in debug verbosity mode. @@ -789,7 +811,7 @@ def log_debug(logstr): ### -def log_failure(logstr): +def log_failure(logstr: str) -> None: """ Log a test failure. Log will only happen if in log failure mode. @@ -808,7 +830,7 @@ def log_failure(logstr): ### -def log_info(logstr, forced=True): +def log_info(logstr: str, forced: bool = True) -> None: """ Log a 'INFO:' message. Args: @@ -824,7 +846,7 @@ def log_info(logstr, forced=True): log(finalstr) ### -def log_error(logstr): +def log_error(logstr: str) -> None: """ Log a 'ERROR:' message. Log will occur no matter what the verbosity is @@ -839,7 +861,7 @@ def log_error(logstr): ### -def log_warning(logstr): +def log_warning(logstr: str) -> None: """ Log a 'WARNING:' message. Log will occur no matter what the verbosity is @@ -853,7 +875,7 @@ def log_warning(logstr): ### -def log_fatal(errstr): +def log_fatal(errstr: str) -> None: """ Log a 'FATAL:' message. Log will occur no matter what the verbosity is and @@ -869,7 +891,7 @@ def log_fatal(errstr): ### -def log_testing_note(note_str): +def log_testing_note(note_str: str) -> None: """ Log a testing note Add a testing note that will be displayed at the end of the test run @@ -892,7 +914,7 @@ def log_testing_note(note_str): ### Testing Directories ################################################################################ -def test_output_get_run_dir(): +def test_output_get_run_dir() -> str: """ Return the path of the output run directory Returns: @@ -902,7 +924,7 @@ def test_output_get_run_dir(): ### -def test_output_get_tmp_dir(): +def test_output_get_tmp_dir() -> str: """ Return the path of the output tmp directory Returns: @@ -980,7 +1002,7 @@ def combine_per_rank_files(filename, header_lines_to_remove=0, remove_header_fro fp.write(line) -def testing_parse_stat(line): +def testing_parse_stat(line: str) -> Optional[List[Union[str, int, float]]]: """ Return a parsed statistic or 'None' if the line does not match a known statistic format This function will recognize an Accumulator statistic in statOutputConsole format that is generated by @@ -1014,7 +1036,13 @@ def testing_parse_stat(line): return None return stat -def testing_stat_output_diff(outfile, reffile, ignore_lines=[], tol_stats={}, new_stats=False): +def testing_stat_output_diff( + outfile: str, + reffile: str, + ignore_lines: List[str] = [], + tol_stats: Dict[str, float] = {}, + new_stats: bool = False, +) -> Tuple[bool, List, List]: """ Perform a diff of statistic outputs with special handling based on arguments This diff is not sensitive to line ordering @@ -1127,20 +1155,20 @@ class LineFilter: Returns: (bool) Filtered line or None if line should be removed """ - def filter(self, line): + def filter(self, line: str) -> Optional[str]: pass - def reset(self): + def reset(self) -> None: pass class StartsWithFilter(LineFilter): """ Filters out any line that starts with a specified string """ - def __init__(self, prefix): + def __init__(self, prefix: str) -> None: self._prefix = prefix; - def filter(self, line): + def filter(self, line: str) -> Optional[str]: """ Checks to see if the line starts with the prefix specified in constructor Args: @@ -1156,15 +1184,15 @@ def filter(self, line): class IgnoreAllAfterFilter(LineFilter): """ Filters out any line that starts with a specified string and all lines after it """ - def __init__(self, prefix, keep_line = False): + def __init__(self, prefix: str, keep_line: bool = False) -> None: self._prefix = prefix; self._keep_line = keep_line self._found = False - def reset(self): + def reset(self) -> None: self._found = False - def filter(self, line): + def filter(self, line: str) -> Optional[str]: """ Checks to see if the line starts with the prefix specified in constructor Args: @@ -1187,10 +1215,10 @@ class IgnoreWhiteSpaceFilter(LineFilter): space. Newlines are not filtered. """ - def __init__(self): + def __init__(self) -> None: pass - def filter(self, line): + def filter(self, line: str) -> Optional[str]: """ Converts any stream of whitespace to a single space. Args: @@ -1215,10 +1243,10 @@ def filter(self, line): class RemoveRegexFromLineFilter(LineFilter): """Filters out portions of line that match the specified regular expression """ - def __init__(self, expr): + def __init__(self, expr: str) -> None: self.regex = expr - def filter(self,line): + def filter(self, line: str) -> Optional[str]: """ Removes all text before the specified prefix. Args: @@ -1267,9 +1295,9 @@ def testing_compare_filtered_diff( outfile: str, reffile: str, sort: bool = False, - filters = list(), + filters: List[LineFilter] = list(), do_statistics_comparison: bool = False, -): +) -> bool: """Filter, optionally sort and then compare 2 files for a difference. Args: @@ -1291,7 +1319,7 @@ def testing_compare_filtered_diff( check_param_type("reffile", reffile, str) if issubclass(type(filters), LineFilter): - filters = [filters] + filters = [filters] # type: ignore check_param_type("filters", filters, list) if not os.path.isfile(outfile): @@ -1326,7 +1354,7 @@ def testing_compare_filtered_diff( ### -def testing_compare_diff(test_name, outfile, reffile, ignore_ws=False): +def testing_compare_diff(test_name: str, outfile: str, reffile: str, ignore_ws: bool = False) -> bool: """ compare 2 files for a diff. Args: @@ -1351,7 +1379,7 @@ def testing_compare_diff(test_name, outfile, reffile, ignore_ws=False): ### -def testing_compare_sorted_diff(test_name, outfile, reffile): +def testing_compare_sorted_diff(test_name: str, outfile: str, reffile: str) -> bool: """ Sort and then compare 2 files for a difference. Args: @@ -1374,7 +1402,7 @@ def testing_compare_sorted_diff(test_name, outfile, reffile): -def testing_compare_filtered_subset(outfile, reffile, filters=[]): +def testing_compare_filtered_subset(outfile: str, reffile: str, filters: List[LineFilter] = []) -> bool: """Filter, and then determine if outfile is a subset of reffile Args: @@ -1393,7 +1421,7 @@ def testing_compare_filtered_subset(outfile, reffile, filters=[]): check_param_type("reffile", reffile, str) if issubclass(type(filters), LineFilter): - filters = [filters] + filters = [filters] # type: ignore check_param_type("filters", filters, list) if not os.path.isfile(outfile): @@ -1445,7 +1473,7 @@ def testing_compare_filtered_subset(outfile, reffile, filters=[]): -def testing_get_diff_data(test_name): +def testing_get_diff_data(test_name: str) -> str: """ Return the diff data file from a regular diff. This should be used after a call to testing_compare_sorted_diff(), testing_compare_diff() or testing_compare_filtered_diff to @@ -1474,7 +1502,12 @@ def testing_get_diff_data(test_name): ### -def testing_merge_mpi_files(filepath_wildcard, mpiout_filename, outputfilepath, errorfilepath=None): +def testing_merge_mpi_files( + filepath_wildcard: str, + mpiout_filename: str, + outputfilepath: str, + errorfilepath: Optional[str] = None, +) -> None: """ Merge a group of common MPI files into an output file This works for OpenMPI 4.x and 5.x ONLY @@ -1528,7 +1561,7 @@ def testing_merge_mpi_files(filepath_wildcard, mpiout_filename, outputfilepath, ### -def testing_remove_component_warning_from_file(input_filepath): +def testing_remove_component_warning_from_file(input_filepath: str) -> None: """ Remove SST Component warnings from a file This will re-write back to the file with the removed warnings @@ -1545,7 +1578,7 @@ def testing_remove_component_warning_from_file(input_filepath): ### OS Basic Or Equivalent Commands ################################################################################ -def os_simple_command(os_cmd, run_dir=None, **kwargs): +def os_simple_command(os_cmd: str, run_dir: Optional[str] = None, **kwargs) -> Tuple[int, str]: """ Perform an simple os command and return a tuple of the (rtncode, rtnoutput). NOTE: Simple command cannot have pipes or redirects @@ -1565,7 +1598,7 @@ def os_simple_command(os_cmd, run_dir=None, **kwargs): rtn_data = (rtn.result(), rtn.output()) return rtn_data -def os_ls(directory=".", echo_out=True, **kwargs): +def os_ls(directory: str = ".", echo_out: bool = True, **kwargs) -> str: """ Perform an simple ls -lia shell command and dump output to screen. Args: @@ -1582,7 +1615,7 @@ def os_ls(directory=".", echo_out=True, **kwargs): log("{0}".format(rtn.output())) return rtn.output() -def os_pwd(echo_out=True, **kwargs): +def os_pwd(echo_out: bool = True, **kwargs) -> str: """ Perform an simple pwd shell command and dump output to screen. Args: @@ -1597,7 +1630,7 @@ def os_pwd(echo_out=True, **kwargs): log("{0}".format(rtn.output())) return rtn.output() -def os_cat(filepath, echo_out=True, **kwargs): +def os_cat(filepath: str, echo_out: bool = True, **kwargs) -> str: """ Perform an simple cat shell command and dump output to screen. Args: @@ -1613,7 +1646,7 @@ def os_cat(filepath, echo_out=True, **kwargs): log("{0}".format(rtn.output())) return rtn.output() -def os_symlink_file(srcdir, destdir, filename): +def os_symlink_file(srcdir: str, destdir: str, filename: str) -> None: """ Create a symlink of a file Args: @@ -1628,7 +1661,7 @@ def os_symlink_file(srcdir, destdir, filename): dstfilepath = "{0}/{1}".format(destdir, filename) os.symlink(srcfilepath, dstfilepath) -def os_symlink_dir(srcdir, destdir): +def os_symlink_dir(srcdir: str, destdir: str) -> None: """ Create a symlink of a directory Args: @@ -1639,7 +1672,7 @@ def os_symlink_dir(srcdir, destdir): check_param_type("destdir", destdir, str) os.symlink(srcdir, destdir) -def os_awk_print(in_str, fields_index_list): +def os_awk_print(in_str: str, fields_index_list: List[int]) -> str: """ Perform an awk / print (equivalent) command which returns specific fields of an input string as a string. @@ -1664,7 +1697,7 @@ def os_awk_print(in_str, fields_index_list): finalstrdata += "{0} ".format(split_list[field_index]) return finalstrdata -def os_wc(in_file, fields_index_list=[], **kwargs): +def os_wc(in_file: str, fields_index_list: List[int] = [], **kwargs) -> str: """ Run a wc (equivalent) command on a file and then extract specific fields of the result as a string. @@ -1687,7 +1720,7 @@ def os_wc(in_file, fields_index_list=[], **kwargs): wc_out = os_awk_print(wc_out, fields_index_list) return wc_out -def os_test_file(file_path, expression="-e", **kwargs): +def os_test_file(file_path: str, expression: str = "-e", **kwargs) -> bool: """ Run a shell 'test' command on a file. Args: @@ -1708,7 +1741,13 @@ def os_test_file(file_path, expression="-e", **kwargs): log_error("File {0} does not exist".format(file_path)) return False -def os_wget(fileurl, targetdir, num_tries=3, secsbetweentries=10, wgetparams=""): +def os_wget( + fileurl: str, + targetdir: str, + num_tries: int = 3, + secsbetweentries: int = 10, + wgetparams: str = "", +) -> bool: """ Perform a wget command to download a file from a url. Args: @@ -1770,7 +1809,7 @@ def os_wget(fileurl, targetdir, num_tries=3, secsbetweentries=10, wgetparams="") return wget_success -def os_extract_tar(tarfilepath, targetdir="."): +def os_extract_tar(tarfilepath: str, targetdir: str = ".") -> bool: """ Extract directories/files from a tar file. Args: @@ -1801,7 +1840,7 @@ def os_extract_tar(tarfilepath, targetdir="."): ### Platform Specific Support Functions ################################################################################ -def _get_linux_distribution(): +def _get_linux_distribution() -> Tuple[str, str]: """ Return the linux distribution info as a tuple""" # The method linux_distribution is depricated in deprecated in Py3.5 _linux_distribution = getattr(platform, 'linux_distribution', None) @@ -1833,7 +1872,7 @@ def _get_linux_distribution(): ### -def _get_linux_version(filepath, sep): +def _get_linux_version(filepath: str, sep: str) -> str: """ return the linux OS version as a string""" # Find the first digit + period in the tokenized string list with open(filepath, 'r') as filehandle: @@ -1855,8 +1894,14 @@ def _get_linux_version(filepath, sep): ### Generic Internal Support Functions ################################################################################ -def _get_sst_config_include_file_value(include_dict, include_source, define, default=None, - data_type=str, disable_warning = False): +def _get_sst_config_include_file_value( + include_dict: Dict[str, Union[str, int]], + include_source: str, + define: str, + default: Optional[Union[str, int]] = None, + data_type: Type = str, + disable_warning: bool = False, +) -> Optional[Union[str, int]]: """ Retrieve a define from an SST Configuration Include File (sst_config.h or sst-element_config.h) include_dict (dict): The dictionary to search for the define include_source (str): The name of the include file we are searching @@ -1929,7 +1974,7 @@ def _handle_config_err(exc_e, default_rtn_data): ### -def _remove_lines_with_string_from_file(removestring, input_filepath): +def _remove_lines_with_string_from_file(removestring: str, input_filepath: str) -> None: bad_strings = [removestring] # Create a temp file diff --git a/src/sst/core/testingframework/test_engine.py b/src/sst/core/testingframework/test_engine.py index b98548e5c..d03b47245 100644 --- a/src/sst/core/testingframework/test_engine.py +++ b/src/sst/core/testingframework/test_engine.py @@ -22,6 +22,7 @@ import argparse import shutil import configparser +from typing import Any, Dict, List import test_engine_globals from sst_unittest import * @@ -103,12 +104,12 @@ ################################################################################ -class TestEngine(): +class TestEngine: """ This is the main Test Engine, it will init arguments, parsed params, create output directories, and then Discover and Run the tests. """ - def __init__(self, sst_core_bin_dir, test_mode): + def __init__(self, sst_core_bin_dir: str, test_mode: int) -> None: """ Initialize the TestEngine object, and parse the user arguments. Args: @@ -209,7 +210,7 @@ def _build_tests_list_helper(self, suite): ################################################################################ ################################################################################ - def _parse_arguments(self): + def _parse_arguments(self) -> None: """ Parse the cmd line arguments.""" # Build a parameter parser, adjust its help based upon the test type helpdesc = HELP_DESC.format(self._test_type_str) @@ -357,7 +358,7 @@ def _decode_parsed_arguments(self, args, parser): #### - def _display_startup_info(self): + def _display_startup_info(self) -> None: """ Display the Test Engine Startup Information""" ver = sys.version_info @@ -437,7 +438,7 @@ def _display_startup_info(self): #### - def _create_all_output_directories(self): + def _create_all_output_directories(self) -> None: """ Create the output directories if needed""" top_dir = test_engine_globals.TESTOUTPUT_TOPDIRPATH run_dir = test_engine_globals.TESTOUTPUT_RUNDIRPATH @@ -458,7 +459,7 @@ def _create_all_output_directories(self): #### - def _discover_testsuites(self): + def _discover_testsuites(self) -> None: """ Figure out the list of paths we are searching for testsuites. The user may have given us a list via the cmd line, so that takes priority """ @@ -495,7 +496,7 @@ def _discover_testsuites(self): #### - def _add_testsuites_from_identifed_paths(self): + def _add_testsuites_from_identifed_paths(self) -> None: """ Look at all the searchable testsuite paths in the list. If its a file, try to add that testsuite directly. If its a directory; add all testsuites that match the identifed testsuite types. @@ -535,7 +536,7 @@ def _add_testsuites_from_identifed_paths(self): #### - def _create_core_config_parser(self): + def _create_core_config_parser(self) -> configparser.RawConfigParser: """ Create an Core Configurtion (INI format) parser. This will allow us to search the Core configuration looking for test file paths. @@ -565,7 +566,7 @@ def _create_core_config_parser(self): ### - def _build_core_config_include_defs_dict(self): + def _build_core_config_include_defs_dict(self) -> Dict[str, str]: """ Create a dictionary of settings from the sst_config.h. This will allow us to search the includes that the core provides. @@ -588,7 +589,7 @@ def _build_core_config_include_defs_dict(self): ### - def _build_elem_config_include_defs_dict(self): + def _build_elem_config_include_defs_dict(self) -> Dict[str, str]: """ Create a dictionary of settings from the sst_element_config.h. This will allow us to search the includes that the elements provides. Note: The Frameworks is runnable even if elements are not built or @@ -624,7 +625,7 @@ def _build_elem_config_include_defs_dict(self): ### - def _read_config_include_defs_dict(self, conf_include_path): + def _read_config_include_defs_dict(self, conf_include_path: str) -> Dict[str, str]: # Read in the file line by line and discard any lines # that do not start with "#define " rtn_dict = {} @@ -647,7 +648,7 @@ def _read_config_include_defs_dict(self, conf_include_path): ### - def _build_list_of_testsuite_dirs(self): + def _build_list_of_testsuite_dirs(self) -> List[str]: """ Using a config file parser, build a list of Test Suite Dirs. Note: The discovery method of Test Suites is different @@ -695,7 +696,7 @@ def _build_list_of_testsuite_dirs(self): #### - def _create_output_dir(self, out_dir): + def _create_output_dir(self, out_dir: str) -> bool: """ Look to see if an output dir exists. If not, try to create it :param: out_dir = The path to the output directory. diff --git a/src/sst/core/testingframework/test_engine_globals.py b/src/sst/core/testingframework/test_engine_globals.py index 199f0485c..38e365df8 100644 --- a/src/sst/core/testingframework/test_engine_globals.py +++ b/src/sst/core/testingframework/test_engine_globals.py @@ -15,6 +15,11 @@ """ import os +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import configparser + # Verbose Defines VERBOSE_QUIET = 0 VERBOSE_NORMAL = 1 @@ -49,7 +54,7 @@ # These are some globals to pass data between the top level test engine # and the lower level testscripts -def init_test_engine_globals(): +def init_test_engine_globals() -> None: """ Initialize the test global variables """ global TESTRUN_TESTRUNNINGFLAG global TESTRUN_SINGTHREAD_TESTSUITE_NAME diff --git a/src/sst/core/testingframework/test_engine_junit.py b/src/sst/core/testingframework/test_engine_junit.py index 3333a3a05..c3118c74a 100644 --- a/src/sst/core/testingframework/test_engine_junit.py +++ b/src/sst/core/testingframework/test_engine_junit.py @@ -47,10 +47,7 @@ import re import xml.etree.ElementTree as ET import xml.dom.minidom - -def _iteritems(_d, **kw): - """ Py3 iteritems() """ - return iter(_d.items(**kw)) +from typing import IO, List, Mapping, Optional ################################################################################ @@ -385,7 +382,7 @@ def junit_to_xml_report_string(test_suites, prettyprint=True, encoding=None): for key in ["time"]: attributes[key] += float(ts_xml.get(key, 0)) xml_element.append(ts_xml) - for key, value in _iteritems(attributes): + for key, value in attributes.items(): xml_element.set(key, str(value)) # Add the name of the testing Frameworks diff --git a/src/sst/core/testingframework/test_engine_support.py b/src/sst/core/testingframework/test_engine_support.py index a5923ab52..92791aa99 100644 --- a/src/sst/core/testingframework/test_engine_support.py +++ b/src/sst/core/testingframework/test_engine_support.py @@ -25,12 +25,13 @@ import inspect import signal from subprocess import TimeoutExpired +from typing import Any, Dict, List, Optional import test_engine_globals ################################################################################ -class OSCommand(): +class OSCommand: """ Enables to run subprocess commands in a different thread with a TIMEOUT option. This will return a OSCommandResult object. diff --git a/src/sst/core/testingframework/test_engine_unittest.py b/src/sst/core/testingframework/test_engine_unittest.py index 32bc2a6c9..e06f7dcee 100644 --- a/src/sst/core/testingframework/test_engine_unittest.py +++ b/src/sst/core/testingframework/test_engine_unittest.py @@ -21,6 +21,7 @@ import threading import time import datetime +from typing import Callable, Dict, List, Optional, TextIO, Tuple, Any if sys.version_info.minor >= 11: def get_current_time(): @@ -31,7 +32,7 @@ def get_current_time(): ################################################################################ -def check_module_conditional_import(module_name): +def check_module_conditional_import(module_name: str) -> bool: """ Test to see if we can import a module See: https://stackoverflow.com/questions/14050281/how-to-check-if-a-python-module-exists-without-importing-it @@ -85,7 +86,7 @@ def check_module_conditional_import(module_name): ################################################################################ -def verify_concurrent_test_engine_available(): +def verify_concurrent_test_engine_available() -> None: """ Check to see if we can load testtools if the user wants to run in concurrent mode. @@ -459,7 +460,7 @@ def addUnexpectedSuccess(self, test): ### - def printErrors(self): + def printErrors(self) -> None: if self.dots or self.showAll: self.stream.writeln() log("=" * 70) @@ -668,59 +669,59 @@ class SSTTestSuiteResultData: """ Support class to hold result data for a specific testsuite Results are stored as lists of test names """ - def __init__(self): - self._tests_passing = [] - self._tests_failing = [] - self._tests_errored = [] - self._tests_skiped = [] - self._tests_expectedfailed = [] - self._tests_unexpectedsuccess = [] - - def add_success(self, test): + def __init__(self) -> None: + self._tests_passing: List[SSTTestCase] = [] + self._tests_failing: List[SSTTestCase] = [] + self._tests_errored: List[SSTTestCase] = [] + self._tests_skiped: List[SSTTestCase] = [] + self._tests_expectedfailed: List[SSTTestCase] = [] + self._tests_unexpectedsuccess: List[SSTTestCase] = [] + + def add_success(self, test: SSTTestCase) -> None: """ Add a test to the success record""" self._tests_passing.append(test) - def add_failure(self, test): + def add_failure(self, test: SSTTestCase) -> None: """ Add a test to the failure record""" self._tests_failing.append(test) - def add_error(self, test): + def add_error(self, test: SSTTestCase) -> None: """ Add a test to the error record""" self._tests_errored.append(test) - def add_skip(self, test): + def add_skip(self, test: SSTTestCase) -> None: """ Add a test to the skip record""" self._tests_skiped.append(test) - def add_expected_failure(self, test): + def add_expected_failure(self, test: SSTTestCase) -> None: """ Add a test to the expected failure record""" self._tests_expectedfailed.append(test) - def add_unexpected_success(self, test): + def add_unexpected_success(self, test: SSTTestCase) -> None: """ Add a test to the unexpected success record""" self._tests_unexpectedsuccess.append(test) - def get_passing(self): + def get_passing(self) -> List[SSTTestCase]: """ Return the tests passing list""" return self._tests_passing - def get_failed(self): + def get_failed(self) -> List[SSTTestCase]: """ Return the tests failed list""" return self._tests_failing - def get_errored(self): + def get_errored(self) -> List[SSTTestCase]: """ Return the tests errored list""" return self._tests_errored - def get_skiped(self): + def get_skiped(self) -> List[SSTTestCase]: """ Return the tests skipped list""" return self._tests_skiped - def get_expectedfailed(self): + def get_expectedfailed(self) -> List[SSTTestCase]: """ Return the expected failed list""" return self._tests_expectedfailed - def get_unexpectedsuccess(self): + def get_unexpectedsuccess(self) -> List[SSTTestCase]: """ Return the tests unexpected success list""" return self._tests_unexpectedsuccess @@ -729,34 +730,34 @@ def get_unexpectedsuccess(self): class SSTTestSuitesResultsDict: """ Support class handle of dict of result data for all testsuites """ - def __init__(self): - self.testsuitesresultsdict = {} + def __init__(self) -> None: + self.testsuitesresultsdict: Dict[str, SSTTestSuiteResultData] = {} - def add_success(self, test): + def add_success(self, test: SSTTestCase) -> None: """ Add a testsuite and test to the success record""" self._get_testresult_from_testmodulecase(test).add_success(test) - def add_failure(self, test): + def add_failure(self, test: SSTTestCase) -> None: """ Add a testsuite and test to the failure record""" self._get_testresult_from_testmodulecase(test).add_failure(test) - def add_error(self, test): + def add_error(self, test: SSTTestCase) -> None: """ Add a testsuite and test to the error record""" self._get_testresult_from_testmodulecase(test).add_error(test) - def add_skip(self, test): + def add_skip(self, test: SSTTestCase) -> None: """ Add a testsuite and test to the skip record""" self._get_testresult_from_testmodulecase(test).add_skip(test) - def add_expected_failure(self, test): + def add_expected_failure(self, test: SSTTestCase) -> None: """ Add a testsuite and test to the expected failure record""" self._get_testresult_from_testmodulecase(test).add_expected_failure(test) - def add_unexpected_success(self, test): + def add_unexpected_success(self, test: SSTTestCase) -> None: """ Add a testsuite and test to the unexpected success record""" self._get_testresult_from_testmodulecase(test).add_unexpected_success(test) - def log_all_results(self): + def log_all_results(self) -> None: """ Log all result catagories by testsuite """ # Log the data by key for tmtc_name in self.testsuitesresultsdict: @@ -774,7 +775,7 @@ def log_all_results(self): for testname in self.testsuitesresultsdict[tmtc_name].get_unexpectedsuccess(): log(" - UNEXPECTED SUCCESS : {0}".format(testname)) - def log_fail_error_skip_unexpeced_results(self): + def log_fail_error_skip_unexpeced_results(self) -> None: """ Log non-success result catagories by testsuite """ # Log the data by key for tmtc_name in self.testsuitesresultsdict: @@ -795,18 +796,18 @@ def log_fail_error_skip_unexpeced_results(self): for testname in self.testsuitesresultsdict[tmtc_name].get_unexpectedsuccess(): log(" - UNEXPECTED SUCCESS : {0}".format(testname)) - def _get_testresult_from_testmodulecase(self, test): + def _get_testresult_from_testmodulecase(self, test: SSTTestCase) -> SSTTestSuiteResultData: tm_tc = self._get_test_module_test_case_name(test) if tm_tc not in self.testsuitesresultsdict.keys(): self.testsuitesresultsdict[tm_tc] = SSTTestSuiteResultData() return self.testsuitesresultsdict[tm_tc] - def _get_test_module_test_case_name(self, test): + def _get_test_module_test_case_name(self, test: SSTTestCase) -> str: return "{0}.{1}".format(self._get_test_module_name(test), self._get_test_case_name(test)) - def _get_test_case_name(self, test): + def _get_test_case_name(self, test: SSTTestCase) -> str: return strqual(test.__class__) - def _get_test_module_name(self, test): + def _get_test_module_name(self, test: SSTTestCase) -> str: return strclass(test.__class__) From 5709836d60801e38940eb794fa6ff67500a2ffb4 Mon Sep 17 00:00:00 2001 From: Eric Berquist Date: Tue, 8 Oct 2024 13:23:40 -0600 Subject: [PATCH 7/7] allow multiple heartbeats when running tests for multiple ranks or threads --- tests/testsuite_default_RealTime.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/testsuite_default_RealTime.py b/tests/testsuite_default_RealTime.py index b4aa993ae..4d95bfc7c 100644 --- a/tests/testsuite_default_RealTime.py +++ b/tests/testsuite_default_RealTime.py @@ -241,7 +241,7 @@ def test_RealTime_SIGUSR1_heartbeat(self): ranks = testing_check_get_num_ranks() threads = testing_check_get_num_threads() num_para = threads * ranks - self.assertTrue(hb_count == 1, "Heartbeat count incorrect, should be {0}, found {1} in {2}".format(num_para,hb_count,outfile)) + self.assertTrue(hb_count >= 1, "Heartbeat count incorrect, should be >= 1, found {0} in {1}".format(hb_count,outfile)) self.assertTrue(exit_count == num_para, "Exit message count incorrect, should be {0}, found {1} in {2}".format(num_para,exit_count,outfile)) @@ -271,7 +271,7 @@ def test_RealTime_SIGUSR2_heartbeat(self): ranks = testing_check_get_num_ranks() threads = testing_check_get_num_threads() num_para = threads * ranks - self.assertTrue(hb_count == 1, "Heartbeat count incorrect, should be {0}, found {1} in {2}".format(num_para,hb_count,outfile)) + self.assertTrue(hb_count >= 1, "Heartbeat count incorrect, should be >= 1, found {0} in {1}".format(hb_count,outfile)) self.assertTrue(exit_count == num_para, "Exit message count incorrect, should be {0}, found {1} in {2}".format(num_para,exit_count,outfile)) # Test SIGALRM + core status + heartbeat action via --sigalrm=