Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add initial set of Python type annotations #1151

Merged
merged 7 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
[tool.mypy]
explicit_package_bases = true
mypy_path = "$MYPY_CONFIG_FILE_DIR/src/sst/core/testingframework"

warn_unused_ignores = true

warn_return_any = true
warn_unused_configs = true

disallow_untyped_defs = true

exclude = [
'^scripts/',
'^tests/',
]

[[tool.mypy.overrides]]
module = "sst"
ignore_missing_imports = true
5 changes: 3 additions & 2 deletions scripts/format-diff
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
import re
from subprocess import check_output,STDOUT
from typing import List

choke_points = [
"ser &",
Expand All @@ -11,11 +12,11 @@ choke_points = [
commit = sys.argv[1]
paths = sys.argv[2:]

def getoutput(cmd_arr):
def getoutput(cmd_arr: List[str]) -> str:
result = check_output(cmd_arr,stderr=STDOUT,stdin=None).decode("utf-8").rstrip("\n")
return result

def format_diff(commit, path):
def format_diff(commit: str, path: str) -> None:
cmd = ["git", "diff", commit, "HEAD", path ]
diff_text = getoutput(cmd)

Expand Down
69 changes: 36 additions & 33 deletions src/sst/core/model/xmlToPython.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@

import xml.etree.ElementTree as ET
import sys, os, re
from typing import Dict

import sst

def printTree(indent, node):
def printTree(indent: int, node: ET.Element) -> None:
print("%sBegin %s: %r"%(' '*indent, node.tag, node.attrib))
if node.text and len(node.text.strip()):
print("%sText: %s"%(' '*indent, node.text.strip()))
Expand All @@ -27,84 +28,86 @@ def printTree(indent, node):


# Various global lookups
sstVars = dict()
sstVars: Dict[str, str] = dict()
sstParams = dict()
sstLinks = dict()


# Some regular expressions
sdlRE = re.compile("<sdl([^/]*?)/>")
commentRE = re.compile("<!--.*?-->", re.DOTALL)
eqRE = re.compile("(<[^>]+?\w+)=([^\"\'][^\\s/>]*)") # This one is suspect
namespaceRE = re.compile("<\s*((\w+):\w+)")
eqRE = re.compile(r"(<[^>]+?\w+)=([^\"\'][^\\s/>]*)") # This one is suspect
namespaceRE = re.compile(r"<\s*((\w+):\w+)")

envVarRE = re.compile("\\${(.*?)}", re.DOTALL)
sstVarRE = re.compile("\\$([^{][a-zA-Z0-9_]+)", re.DOTALL)


def processString(str):
def processString(string: str) -> str:
"""Process a string, replacing variables and env. vars with their values"""
def replaceSSTVar(matchobj):
def replaceSSTVar(matchobj: re.Match) -> str:
varname = matchobj.group(1)
return sstVars[varname]

def replaceEnvVar(matchobj):
def replaceEnvVar(matchobj: re.Match) -> str:
varname = matchobj.group(1)
return os.getenv(varname)
var = os.getenv(varname)
assert var is not None
return var

str = envVarRE.sub(replaceEnvVar, str)
str = sstVarRE.sub(replaceSSTVar, str)
return str
string = envVarRE.sub(replaceEnvVar, string)
string = sstVarRE.sub(replaceSSTVar, string)
return string


def getLink(name):
def getLink(name: str) -> sst.Link:
if name not in sstLinks:
sstLinks[name] = sst.Link(name)
return sstLinks[name]



def getParamName(node):
def getParamName(node: ET.Element) -> str:
name = node.tag.strip()
if name[0] == "{":
ns, tag = name[1:].split("}")
return ns + ":" + tag
return name


def processParamSets(set):
for group in set:
def processParamSets(groups: ET.Element) -> None:
for group in groups:
params = dict()
for p in group:
params[getParamName(p)] = processString(p.text.strip())
params[getParamName(p)] = processString(p.text.strip()) # type: ignore
sstParams[group.tag] = params


def processVars(varNode):
def processVars(varNode: ET.Element) -> None:
for var in varNode:
sstVars[var.tag] = processString(var.text.strip())
sstVars[var.tag] = processString(var.text.strip()) # type: ignore

def processConfig(cfg):
for line in cfg.text.strip().splitlines():
def processConfig(cfg: ET.Element) -> None:
for line in cfg.text.strip().splitlines(): # type: ignore
var, val = line.split('=')
sst.setProgramOption(var, processString(val)) # strip quotes



def buildComp(compNode):
def buildComp(compNode: ET.Element) -> None:
name = processString(compNode.attrib['name'])
type = processString(compNode.attrib['type'])
comp = sst.Component(name, type)

# Process Parameters
paramsNode = compNode.find("params")
params = dict()
if paramsNode != None:
if paramsNode is not None:
if "include" in paramsNode.attrib:
for paramInc in paramsNode.attrib['include'].split(','):
params.update(sstParams[processString(paramInc)])
for p in paramsNode:
params[getParamName(p)] = processString(p.text.strip())
params[getParamName(p)] = processString(p.text.strip()) # type: ignore

comp.addParams(params)

Expand All @@ -125,29 +128,29 @@ def buildComp(compNode):



def buildGraph(graph):
def buildGraph(graph: ET.Element) -> None:
for comp in graph.findall("component"):
buildComp(comp)



def build(root):
def build(root: ET.Element) -> None:
paramSets = root.find("param_include")
vars = root.find("variables")
variables = root.find("variables")
cfg = root.find("config")
timebase = root.find("timebase")
graph = root.find("sst")


if None != vars:
processVars(vars)
if None != paramSets:
if variables is not None:
processVars(variables)
if paramSets is not None:
processParamSets(paramSets)
if None != timebase:
sst.setProgramOption('timebase', timebase.text.strip())
if None != cfg:
if timebase is not None:
sst.setProgramOption('timebase', timebase.text.strip()) # type: ignore
if cfg is not None:
processConfig(cfg)
if None != graph:
if graph is not None:
buildGraph(graph)


Expand Down
4 changes: 2 additions & 2 deletions src/sst/core/testingframework/sst_test_engine_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

################################################################################

def startup_and_run(sst_core_bin_dir, test_mode):
def startup_and_run(sst_core_bin_dir: str, test_mode: int) -> None:
""" This is the main entry point for loading and running the SST Test Frameworks
Engine.

Expand Down Expand Up @@ -138,7 +138,7 @@ def _generic_exception_handler(exc_e):

####

def _verify_test_frameworks_is_available(sst_core_frameworks_dir):
def _verify_test_frameworks_is_available(sst_core_frameworks_dir: str) -> None:
""" Ensure that all test framework files are available.
:param: sst_core_frameworks_dir = Dir of the test frameworks
"""
Expand Down
33 changes: 17 additions & 16 deletions src/sst/core/testingframework/sst_unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import threading
import signal
import time
from typing import Optional

import test_engine_globals
from sst_unittest_support import *
Expand All @@ -40,7 +41,7 @@
#from test_engine_junit import junit_to_xml_report_string

if not sys.warnoptions:
import os, warnings
import warnings
warnings.simplefilter("once") # Change the filter in this process
os.environ["PYTHONWARNINGS"] = "once" # Also affect subprocesses

Expand All @@ -54,12 +55,12 @@ class SSTTestCase(unittest.TestCase):
basic resource for how to develop tests for this frameworks.
"""

def __init__(self, methodName):
def __init__(self, methodName: str) -> None:
# NOTE: __init__ is called at startup for all tests before any
# setUpModules(), setUpClass(), setUp() and the like are called.
super(SSTTestCase, self).__init__(methodName)
self.testname = methodName
parent_module_path = os.path.dirname(sys.modules[self.__class__.__module__].__file__)
parent_module_path: str = os.path.dirname(sys.modules[self.__class__.__module__].__file__) # type: ignore
self._testsuite_dirpath = parent_module_path
#log_forced("SSTTestCase: __init__() - {0}".format(self.testname))
self.initializeClass(self.testname)
Expand All @@ -68,7 +69,7 @@ def __init__(self, methodName):

###

def initializeClass(self, testname):
def initializeClass(self, testname: str) -> None:
""" The method is called by the Frameworks immediately before class is
initialized.

Expand All @@ -92,7 +93,7 @@ def initializeClass(self, testname):

###

def setUp(self):
def setUp(self) -> None:
""" The method is called by the Frameworks immediately before a test is run

**NOTICE**:
Expand All @@ -115,7 +116,7 @@ def setUp(self):

###

def tearDown(self):
def tearDown(self) -> None:
""" The method is called by the Frameworks immediately after a test finishes

**NOTICE**:
Expand All @@ -136,7 +137,7 @@ def tearDown(self):
###

@classmethod
def setUpClass(cls):
def setUpClass(cls) -> None:
""" This method is called by the Frameworks immediately before the TestCase starts

**NOTICE**:
Expand All @@ -154,7 +155,7 @@ def setUpClass(cls):
###

@classmethod
def tearDownClass(cls):
def tearDownClass(cls) -> None:
""" This method is called by the Frameworks immediately after a TestCase finishes

**NOTICE**:
Expand All @@ -171,7 +172,7 @@ def tearDownClass(cls):

###

def get_testsuite_name(self):
def get_testsuite_name(self) -> str:
""" Return the testsuite (module) name

Returns:
Expand All @@ -181,7 +182,7 @@ def get_testsuite_name(self):

###

def get_testcase_name(self):
def get_testcase_name(self) -> str:
""" Return the testcase name

Returns:
Expand All @@ -190,7 +191,7 @@ def get_testcase_name(self):
return "{0}".format(strqual(self.__class__))
###

def get_testsuite_dir(self):
def get_testsuite_dir(self) -> str:
""" Return the directory path of the testsuite that is being run

Returns:
Expand All @@ -200,7 +201,7 @@ def get_testsuite_dir(self):

###

def get_test_output_run_dir(self):
def get_test_output_run_dir(self) -> str:
""" Return the path of the test output run directory

Returns:
Expand All @@ -210,7 +211,7 @@ def get_test_output_run_dir(self):

###

def get_test_output_tmp_dir(self):
def get_test_output_tmp_dir(self) -> str:
""" Return the path of the test tmp directory

Returns:
Expand All @@ -220,7 +221,7 @@ def get_test_output_tmp_dir(self):

###

def get_test_runtime_sec(self):
def get_test_runtime_sec(self) -> float:
""" Return the current runtime (walltime) of the test

Returns:
Expand Down Expand Up @@ -377,7 +378,7 @@ def run_sst(self, sdl_file, out_file, err_file=None, set_cwd=None, mpi_out_files
### Module level support
################################################################################

def setUpModule():
def setUpModule() -> None:
""" Perform setup functions before the testing Module loads.

This function is called by the Frameworks before tests in any TestCase
Expand All @@ -400,7 +401,7 @@ def setUpModule():

###

def tearDownModule():
def tearDownModule() -> None:
""" Perform teardown functions immediately after a testing Module finishes.

This function is called by the Frameworks after all tests in all TestCases
Expand Down
Loading
Loading