diff --git a/pyproject.toml b/pyproject.toml index 3acf91a4a68..dde4749ada0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,3 +42,7 @@ git_describe_command = ["sh", "-c", "tag=$(git tag | grep -v '-' | sort | tail - [tool.setuptools.packages.find] include = ["ssg*"] + +[[tool.mypy.overrides]] +module = "pkg_resources" +ignore_missing_imports = true diff --git a/ssg/build_cpe.py b/ssg/build_cpe.py index 9b5eba39178..b29ae2d59a6 100644 --- a/ssg/build_cpe.py +++ b/ssg/build_cpe.py @@ -6,8 +6,8 @@ from __future__ import print_function import os import sys -import ssg.id_translate +import ssg.id_translate from .constants import oval_namespace from .constants import PREFIX_TO_NS from .utils import required_key, apply_formatting_on_dict_values @@ -207,9 +207,9 @@ class CPEItem(XCCDFEntity, Templatable): ) KEYS.update(**Templatable.KEYS) - MANDATORY_KEYS = [ + MANDATORY_KEYS = { "name", - ] + } prefix = "cpe-dict" ns = PREFIX_TO_NS[prefix] diff --git a/ssg/build_remediations.py b/ssg/build_remediations.py index 3cffd49987e..bba4f2b48a3 100644 --- a/ssg/build_remediations.py +++ b/ssg/build_remediations.py @@ -37,7 +37,7 @@ 'strategy'] REMEDIATION_ELM_KEYS = ['complexity', 'disruption', 'reboot', 'strategy'] -RemediationObject = namedtuple('remediation', ['contents', 'config']) +RemediationObject = namedtuple('RemediationObject', ['contents', 'config']) def is_supported_filename(remediation_type, filename): diff --git a/ssg/build_renumber.py b/ssg/build_renumber.py index ad3ba04e8eb..745a086599c 100644 --- a/ssg/build_renumber.py +++ b/ssg/build_renumber.py @@ -2,7 +2,7 @@ from __future__ import print_function import sys import os - +from typing import Optional from .constants import ( OSCAP_RULE, OSCAP_VALUE, oval_namespace, XCCDF12_NS, cce_uri, ocil_cs, @@ -26,8 +26,8 @@ class FileLinker(object): Bass class which represents the linking of checks to their identifiers. """ - CHECK_SYSTEM = None - CHECK_NAMESPACE = None + CHECK_SYSTEM: Optional[str] = None + CHECK_NAMESPACE: Optional[str] = None def __init__(self, translator, xccdftree, checks, output_file_name): self.translator = translator diff --git a/ssg/build_yaml.py b/ssg/build_yaml.py index 9cb01fcefb2..274c3efd1ac 100644 --- a/ssg/build_yaml.py +++ b/ssg/build_yaml.py @@ -3426,13 +3426,13 @@ class Platform(XCCDFEntity): ** XCCDFEntity.KEYS ) - MANDATORY_KEYS = [ + MANDATORY_KEYS = { "name", "xml_content", "original_expression", "bash_conditional", "ansible_conditional" - ] + } prefix = "cpe-lang" ns = PREFIX_TO_NS[prefix] diff --git a/ssg/cce.py b/ssg/cce.py index c410115a90c..c580bd4b48f 100644 --- a/ssg/cce.py +++ b/ssg/cce.py @@ -5,9 +5,7 @@ import re import random import os - - -CCE_POOLS = dict() +from typing import Dict, Type class CCEFile: @@ -136,6 +134,7 @@ def absolute_path(self): return os.path.join(self.project_root, "shared", "references", "cce-sle15-avail.txt") +CCE_POOLS: Dict[str, Type[CCEFile]] = {} CCE_POOLS["redhat"] = RedhatCCEFile CCE_POOLS["sle12"] = SLE12CCEFile CCE_POOLS["sle15"] = SLE15CCEFile diff --git a/ssg/controls.py b/ssg/controls.py index 054680630ed..a4f33e28f9c 100644 --- a/ssg/controls.py +++ b/ssg/controls.py @@ -7,7 +7,7 @@ import copy import sys from glob import glob -from typing import List, Dict +from typing import Dict, List, Set import ssg.entities.common import ssg.yaml @@ -144,7 +144,7 @@ class Control(ssg.entities.common.SelectionHandler, ssg.entities.common.XCCDFEnt description=str, rationale=str, automated=str, - status=None, + status=lambda: None, mitigation=str, artifact_description=str, status_justification=str, @@ -812,7 +812,7 @@ def __init__(self, controls_dirs: List[str], env_yaml=None, existing_rules=None) self.controls_dirs = [os.path.abspath(controls_dir) for controls_dir in controls_dirs] self.env_yaml = env_yaml self.existing_rules = existing_rules - self.policies = {} + self.policies: Dict = {} def _load(self, format): for controls_dir in self.controls_dirs: diff --git a/ssg/entities/common.py b/ssg/entities/common.py index c0fd2498d1c..aecac8a60ff 100644 --- a/ssg/entities/common.py +++ b/ssg/entities/common.py @@ -6,12 +6,12 @@ import yaml from collections import defaultdict from copy import deepcopy +from typing import Set, Dict, Callable, Any, Optional from ssg.yaml import yaml_Dumper from ..xml import ElementTree as ET, add_xhtml_namespace from ..yaml import DocumentationNotComplete, open_and_expand -from ..shims import unicode_func from ..constants import ( xhtml_namespace, @@ -112,7 +112,7 @@ def add_sub_element(parent, tag, ns, data): # and therefore it does not add child elements # we need to do a hack instead # TODO: Remove this function after we move to Markdown everywhere in SSG - ustr = unicode_func('<{0} xmlns="{3}" xmlns:xhtml="{2}">{1}').format( + ustr = str('<{0} xmlns="{3}" xmlns:xhtml="{2}">{1}').format( tag, namespaced_data, xhtml_namespace, ns) try: @@ -156,15 +156,15 @@ class XCCDFEntity(object): when entities are defined in the benchmark tree, and they are compiled into flat YAMLs to the build directory. """ - KEYS = dict( + KEYS: Dict[str, Callable[[], Optional[Any]]] = dict( id_=lambda: "", title=lambda: "", definition_location=lambda: "", ) - MANDATORY_KEYS = set() + MANDATORY_KEYS: Set[str] = set() - ALTERNATIVE_KEYS = dict() + ALTERNATIVE_KEYS: Dict[str, str] = {} GENERIC_FILENAME = "" ID_LABEL = "id" diff --git a/ssg/ext/boolean/boolean.py b/ssg/ext/boolean/boolean.py index 789cbb9f989..dddeff02993 100644 --- a/ssg/ext/boolean/boolean.py +++ b/ssg/ext/boolean/boolean.py @@ -29,19 +29,8 @@ from operator import and_ as and_operator from operator import or_ as or_operator -# Python 2 and 3 -try: - basestring # NOQA -except NameError: - basestring = str # NOQA - -# Python 2 and 3 -try: - # Python 2 - reduce # NOQA -except NameError: - # Python 3 - from functools import reduce # NOQA +from functools import reduce +from typing import Optional, Any # Set to True to enable tracing for parsing TRACE_PARSE = False @@ -208,7 +197,7 @@ def parse(self, expr, simplify=False): precedence = {self.NOT: 5, self.AND: 10, self.OR: 15, TOKEN_LPAR: 20} - if isinstance(expr, basestring): + if isinstance(expr, str): tokenized = self.tokenize(expr) else: tokenized = iter(expr) @@ -445,7 +434,7 @@ def tokenize(self, expr): - True symbols: 1 and True - False symbols: 0, False and None """ - if not isinstance(expr, basestring): + if not isinstance(expr, str): raise TypeError('expr must be string but it is %s.' % type(expr)) # mapping of lowercase token strings to a token type id for the standard @@ -556,17 +545,17 @@ class Expression(object): variable symbols. """ # Defines sort and comparison order between expressions arguments - sort_order = None + sort_order: Optional[int] = None # Store arguments aka. subterms of this expressions. # subterms are either literals or expressions. - args = tuple() + args: Any = () # True is this is a literal expression such as a Symbol, TRUE or FALSE - isliteral = False + isliteral: bool = False # True if this expression has been simplified to in canonical form. - iscanonical = False + iscanonical: bool = False # these class attributes are configured when a new BooleanAlgebra is created TRUE = None @@ -809,7 +798,7 @@ def __lt__(self, other): return self == self.FALSE return NotImplemented - __nonzero__ = __bool__ = lambda s: None + __nonzero__ = __bool__ = lambda s: None # type: ignore def pretty(self, indent=0, debug=False): """ @@ -843,7 +832,7 @@ def __repr__(self): def __call__(self): return self - __nonzero__ = __bool__ = lambda s: True + __nonzero__ = __bool__ = lambda s: True # type: ignore class _FALSE(BaseElement): @@ -871,7 +860,7 @@ def __repr__(self): def __call__(self): return self - __nonzero__ = __bool__ = lambda s: False + __nonzero__ = __bool__ = lambda s: False # type: ignore class Symbol(Expression): @@ -920,7 +909,7 @@ def __str__(self): return str(self.obj) def __repr__(self): - obj = "'%s'" % self.obj if isinstance(self.obj, basestring) else repr(self.obj) + obj = "'%s'" % self.obj if isinstance(self.obj, str) else repr(self.obj) return '%s(%s)' % (self.__class__.__name__, obj) def pretty(self, indent=0, debug=False): @@ -931,7 +920,7 @@ def pretty(self, indent=0, debug=False): if debug: debug_details += '' % (self.isliteral, self.iscanonical) - obj = "'%s'" % self.obj if isinstance(self.obj, basestring) else repr(self.obj) + obj = "'%s'" % self.obj if isinstance(self.obj, str) else repr(self.obj) return (' ' * indent) + ('%s(%s%s)' % (self.__class__.__name__, debug_details, obj)) @@ -1470,7 +1459,7 @@ class AND(DualBase): """ sort_order = 10 - _pyoperator = and_operator + _pyoperator = and_operator # type: ignore def __init__(self, arg1, arg2, *args): super(AND, self).__init__(arg1, arg2, *args) @@ -1496,7 +1485,7 @@ class OR(DualBase): """ sort_order = 25 - _pyoperator = or_operator + _pyoperator = or_operator # type: ignore def __init__(self, arg1, arg2, *args): super(OR, self).__init__(arg1, arg2, *args) diff --git a/ssg/jinja.py b/ssg/jinja.py index aa2a6c651bd..414e1bcfdd5 100644 --- a/ssg/jinja.py +++ b/ssg/jinja.py @@ -9,15 +9,8 @@ import sys import jinja2 -try: - from urllib.parse import quote -except ImportError: - from urllib import quote - -try: - from shlex import quote as shell_quote -except ImportError: - from pipes import quote as shell_quote +from urllib.parse import quote +from shlex import quote as shell_quote from .constants import JINJA_MACROS_DIRECTORY from .utils import (required_key, @@ -126,6 +119,10 @@ def __init__(self, bytecode_cache=None): ) +# Module-level cached environment for jinja environment +_jinja_env = None + + def _get_jinja_environment(substitutions_dict): """ Initializes and returns a Jinja2 Environment with custom settings and filters. @@ -145,7 +142,8 @@ def _get_jinja_environment(substitutions_dict): Returns: jinja2.Environment: The configured Jinja2 Environment instance. """ - if _get_jinja_environment.env is None: + global _jinja_env + if _jinja_env is None: bytecode_cache = None if substitutions_dict.get("jinja2_cache_enabled") == "true": bytecode_cache = jinja2.FileSystemBytecodeCache( @@ -153,22 +151,19 @@ def _get_jinja_environment(substitutions_dict): ) # TODO: Choose better syntax? - _get_jinja_environment.env = JinjaEnvironment(bytecode_cache=bytecode_cache) + _jinja_env = JinjaEnvironment(bytecode_cache=bytecode_cache) add_python_functions(substitutions_dict) - _get_jinja_environment.env.filters['banner_anchor_wrap'] = banner_anchor_wrap - _get_jinja_environment.env.filters['banner_regexify'] = banner_regexify - _get_jinja_environment.env.filters['escape_id'] = escape_id - _get_jinja_environment.env.filters['escape_regex'] = escape_regex - _get_jinja_environment.env.filters['escape_yaml_key'] = escape_yaml_key - _get_jinja_environment.env.filters['quote'] = shell_quote - _get_jinja_environment.env.filters['sha256'] = sha256 - _get_jinja_environment.env.globals.update(substitutions_dict) - preload_macros(_get_jinja_environment.env) - - return _get_jinja_environment.env - - -_get_jinja_environment.env = None + _jinja_env.filters['banner_anchor_wrap'] = banner_anchor_wrap + _jinja_env.filters['banner_regexify'] = banner_regexify + _jinja_env.filters['escape_id'] = escape_id + _jinja_env.filters['escape_regex'] = escape_regex + _jinja_env.filters['escape_yaml_key'] = escape_yaml_key + _jinja_env.filters['quote'] = shell_quote + _jinja_env.filters['sha256'] = sha256 + _jinja_env.globals.update(substitutions_dict) + preload_macros(_jinja_env) + + return _jinja_env def initialize(substitutions_dict): diff --git a/ssg/oval.py b/ssg/oval.py index 00b61c56704..81baadb49a4 100644 --- a/ssg/oval.py +++ b/ssg/oval.py @@ -19,13 +19,8 @@ ASSUMED_OVAL_VERSION_STRING = "5.11" # globals, to make recursion easier in case we encounter extend_definition -try: - ET.register_namespace("oval", ovalns) -except AttributeError: - # Legacy Python 2.6 fix, see e.g. - # https://www.programcreek.com/python/example/57552/xml.etree.ElementTree._namespace_map - from xml.etree import ElementTree as ET - ET._namespace_map[ovalns] = "oval" +ET.register_namespace("oval", ovalns) + def applicable_platforms(oval_file, oval_version_string=None): diff --git a/ssg/products.py b/ssg/products.py index 4b903928965..c08f9cb7562 100644 --- a/ssg/products.py +++ b/ssg/products.py @@ -353,7 +353,7 @@ def read_properties_from_directory(self, path): self.expand_by_acquired_data(new_symbols) -def load_product_yaml(product_yaml_path): +def load_product_yaml(product_yaml_path: str) -> Product: """ Reads a product data from disk and returns it. diff --git a/ssg/profiles.py b/ssg/profiles.py index d90b48afbbd..ee3e0fb1e2a 100644 --- a/ssg/profiles.py +++ b/ssg/profiles.py @@ -4,7 +4,9 @@ import os import sys import yaml +from typing import Optional, Tuple, List, Dict +import ssg.controls from .controls import ControlsManager, Policy from .products import ( get_profile_files_from_root, @@ -13,16 +15,6 @@ ) -if sys.version_info >= (3, 9): - dict_type = dict # Python 3.9+ supports built-in generics - list_type = list - tuple_type = tuple -else: - from typing import Dict as dict_type # Fallback for older versions - from typing import List as list_type - from typing import Tuple as tuple_type - - class ProfileSelections: """ A class to represent profile with sections of rules and variables. @@ -48,7 +40,7 @@ def __init__(self, profile_id, profile_title, product_id, product_title): self.variables = {} -def _load_product_yaml(content_dir: str, product: str) -> object: +def _load_product_yaml(content_dir: str, product: str) -> ssg.products.Product: """ Load the product YAML file and return its content as a Python object. @@ -63,7 +55,7 @@ def _load_product_yaml(content_dir: str, product: str) -> object: return load_product_yaml(file_yaml_path) -def _load_yaml_profile_file(file_path: str) -> dict_type: +def _load_yaml_profile_file(file_path: str) -> Dict: """ Load the content of a YAML file intended to profiles definitions. @@ -83,7 +75,7 @@ def _load_yaml_profile_file(file_path: str) -> dict_type: return {} -def _get_extended_profile_path(profiles_files: list, profile_name: str) -> str: +def _get_extended_profile_path(profiles_files: list, profile_name: str) -> Optional[str]: """ Retrieve the full path of a profile file from a list of profile file paths. @@ -92,10 +84,11 @@ def _get_extended_profile_path(profiles_files: list, profile_name: str) -> str: profile_name (str): The name of the profile to search for. Returns: - str: The full path of the profile file if found, otherwise None. + Optional[str]: The full path of the profile file if found, otherwise None. """ profile_file = f"{profile_name}.profile" - profile_path = next((path for path in profiles_files if profile_file in path), None) + profile_path: Optional[str] = next((path for path in profiles_files if profile_file in path), + None) return profile_path @@ -123,7 +116,7 @@ def _process_profile_extension(profile: ProfileSelections, profile_yaml: dict, return profile -def _parse_control_line(control_line: str) -> tuple_type[str, str]: +def _parse_control_line(control_line: str) -> Tuple[str, str]: """ Parses a control line string and returns a tuple containing the first and third parts of the string, separated by a colon. If the string does not contain three parts, the second element @@ -172,7 +165,7 @@ def _process_selected_rule(profile: ProfileSelections, rule: str) -> None: profile.rules.append(rule) -def _process_control(profile: ProfileSelections, control: object) -> None: +def _process_control(profile: ProfileSelections, control: ssg.controls.Control) -> None: """ Processes a control by iterating through its rules and applying the appropriate processing function. Note that at this level rules list in control can include both variables and rules. @@ -288,7 +281,7 @@ def _process_profile(profile: ProfileSelections, profile_yaml: dict, profiles_fi return profile -def _load_controls_manager(controls_dir: str, product_yaml: dict) -> object: +def _load_controls_manager(controls_dir: str, product_yaml: dict) -> ssg.controls.ControlsManager: """ Loads and initializes a ControlsManager instance. @@ -308,7 +301,7 @@ def _load_controls_manager(controls_dir: str, product_yaml: dict) -> object: return control_mgr -def _sort_profiles_selections(profiles: list) -> list_type[ProfileSelections]: +def _sort_profiles_selections(profiles: list) -> List[ProfileSelections]: """ Sorts profiles selections (rules and variables) by selections ids. @@ -326,13 +319,14 @@ def _sort_profiles_selections(profiles: list) -> list_type[ProfileSelections]: def get_profiles_from_products(content_dir: str, products: list, - sorted: bool = False) -> list_type: + sorted: bool = False) -> list: """ Retrieves profiles with respective variables from the given products. Args: content_dir (str): The directory containing the content. products (list): A list of product names to retrieve profiles from. + sorted (bool): Sorts the profile selections if true, defaults to False Returns: list: A list of ProfileVariables objects containing profile variables for each product. diff --git a/ssg/shims.py b/ssg/shims.py index 0bd9fb5bfb0..388ae1b492e 100644 --- a/ssg/shims.py +++ b/ssg/shims.py @@ -3,10 +3,7 @@ import subprocess -try: - import queue as Queue -except ImportError: - import Queue +import queue as Queue def subprocess_check_output(*popenargs, **kwargs): @@ -57,9 +54,3 @@ def input_func(prompt=None): except NameError: return input(prompt) - -unicode_func = str -try: - unicode_func = unicode -except NameError: - pass diff --git a/ssg/templates.py b/ssg/templates.py index 21de3ce5809..e8a3b7f6801 100644 --- a/ssg/templates.py +++ b/ssg/templates.py @@ -19,7 +19,7 @@ from ssg.build_cpe import ProductCPEs TemplatingLang = namedtuple( - "templating_language_attributes", + "TemplatingLang", ["name", "file_extension", "template_type", "lang_specific_dir"]) TemplateType = ssg.utils.enum("REMEDIATION", "CHECK") @@ -43,12 +43,11 @@ TEMPLATE_YAML_FILE_NAME = "template.yml" -def load_module(module_name, module_path): +def load_module(module_name: str, module_path: str): """ Loads a Python module from a given file path. - This function attempts to load a module using the `imp` module for Python 2.7 and falls back - to using `importlib` for Python 3.x. + This function attempts to load a module using `importlib`. Args: module_name (str): The name to assign to the loaded module. @@ -60,21 +59,16 @@ def load_module(module_name, module_path): Raises: ValueError: If the module cannot be loaded due to an invalid spec or loader. """ - try: - # Python 2.7 - from imp import load_source - return load_source(module_name, module_path) - except ImportError: - # https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly - import importlib - spec = importlib.util.spec_from_file_location(module_name, module_path) - if not spec: - raise ValueError("Error loading '%s' module" % module_path) - module = importlib.util.module_from_spec(spec) - if not spec.loader: - raise ValueError("Error loading '%s' module" % module_path) - spec.loader.exec_module(module) - return module + # https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly + import importlib + spec = importlib.util.spec_from_file_location(module_name, module_path) # type: ignore + if not spec: + raise ValueError("Error loading '%s' module" % module_path) + module = importlib.util.module_from_spec(spec) # type: ignore + if not spec.loader: + raise ValueError("Error loading '%s' module" % module_path) + spec.loader.exec_module(module) + return module class Template: diff --git a/ssg/variables.py b/ssg/variables.py index 4a9f9b3bfcc..0219772fe54 100644 --- a/ssg/variables.py +++ b/ssg/variables.py @@ -3,28 +3,20 @@ import glob import os -import sys from collections import defaultdict +from typing import Optional, List, Dict, Any, DefaultDict + from .constants import BENCHMARKS from .profiles import get_profiles_from_products from .yaml import open_and_macro_expand_from_dir - -if sys.version_info >= (3, 9): - list_type = list # Python 3.9+ supports built-in generics - dict_type = dict -else: - from typing import List as list_type # Fallback for older versions - from typing import Dict as dict_type - - # Cache variable files and respective content to avoid multiple reads -_var_files_cache = {} -_vars_content_cache = {} +_var_files_cache: Dict[str, List[str]] = {} +_vars_content_cache: Dict[str, Any] = {} -def get_variable_files_in_folder(content_dir: str, subfolder: str) -> list_type[str]: +def get_variable_files_in_folder(content_dir: str, subfolder: str) -> List[str]: """ Retrieve a list of variable files within a specified folder in the project. @@ -41,7 +33,7 @@ def get_variable_files_in_folder(content_dir: str, subfolder: str) -> list_type[ return glob.glob(pattern, recursive=True) -def get_variable_files(content_dir: str) -> list_type[str]: +def get_variable_files(content_dir: str) -> List[str]: """ Retrieves all variable files from the specified content root directory. @@ -64,7 +56,7 @@ def get_variable_files(content_dir: str) -> list_type[str]: return variable_files -def _get_variables_content(content_dir: str) -> dict_type: +def _get_variables_content(content_dir: str) -> Dict[str, Any]: """ Retrieve the content of all variable files from the specified content root directory. @@ -107,11 +99,11 @@ def get_variable_property(content_dir: str, variable_id: str, property_name: str str: The value of the specified property for the variable. """ variables_content = _get_variables_content(content_dir) - variable_content = variables_content.get(variable_id, {}) + variable_content: Dict[str, str] = variables_content.get(variable_id, {}) return variable_content.get(property_name, '') -def get_variable_options(content_dir: str, variable_id: str = None) -> dict_type: +def get_variable_options(content_dir: str, variable_id: Optional[str] = None) -> Dict[str, Dict[str, str]]: """ Retrieve the options for specific or all variables from the content root directory. @@ -146,7 +138,7 @@ def get_variable_options(content_dir: str, variable_id: str = None) -> dict_type return all_options -def get_variables_from_profiles(profiles: list) -> dict_type: +def get_variables_from_profiles(profiles: list) -> Dict[str, Dict[str, str]]: """ Extracts variables from a list of profiles and organizes them into a nested dictionary. @@ -158,14 +150,14 @@ def get_variables_from_profiles(profiles: list) -> dict_type: keys are product names, and the third level keys are profile IDs, with the corresponding values being the variable values. """ - variables = defaultdict(lambda: defaultdict(dict)) + variables: DefaultDict[str, DefaultDict[str, Dict[str, str]]] = defaultdict(lambda: defaultdict(dict)) for profile in profiles: for variable, value in profile.variables.items(): variables[variable][profile.product_id][profile.profile_id] = value return _convert_defaultdict_to_dict(variables) -def _convert_defaultdict_to_dict(dictionary: defaultdict) -> dict_type: +def _convert_defaultdict_to_dict(dictionary: defaultdict) -> Dict[Any, Any]: """ Recursively converts a defaultdict to a regular dictionary. @@ -176,11 +168,11 @@ def _convert_defaultdict_to_dict(dictionary: defaultdict) -> dict_type: dict: The converted dictionary. """ if isinstance(dictionary, defaultdict): - dictionary = {k: _convert_defaultdict_to_dict(v) for k, v in dictionary.items()} + dictionary = {k: _convert_defaultdict_to_dict(v) for k, v in dictionary.items()} # type: ignore[assignment] return dictionary -def get_variables_by_products(content_dir: str, products: list) -> dict_type[str, dict]: +def get_variables_by_products(content_dir: str, products: list) -> Dict[str, dict]: """ Retrieve variables by products from the specified content root directory. @@ -197,11 +189,10 @@ def get_variables_by_products(content_dir: str, products: list) -> dict_type[str product-profile pairs. """ profiles = get_profiles_from_products(content_dir, products) - profiles_variables = get_variables_from_profiles(profiles) - return _convert_defaultdict_to_dict(profiles_variables) + return get_variables_from_profiles(profiles) -def get_variable_values(content_dir: str, profiles_variables: dict) -> dict_type: +def get_variable_values(content_dir: str, profiles_variables: dict) -> dict: """ Update the variables dictionary with actual values for each variable option. @@ -211,7 +202,7 @@ def get_variable_values(content_dir: str, profiles_variables: dict) -> dict_type Args: content_dir (str): The root directory of the content. - variables (dict): A dictionary where keys are variable names and values are dictionaries + profiles_variables (dict): A dictionary where keys are variable names and values are dictionaries of product-profile pairs. Returns: @@ -220,7 +211,7 @@ def get_variable_values(content_dir: str, profiles_variables: dict) -> dict_type all_variables_options = get_variable_options(content_dir) for variable in profiles_variables: - variable_options = all_variables_options.get(variable, {}) + variable_options: Dict[str, str] = all_variables_options.get(variable, {}) for product, profiles in profiles_variables[variable].items(): for profile in profiles: profile_option = profiles.get(profile, None) diff --git a/ssg/xml.py b/ssg/xml.py index ba26254aebe..9f75e67030b 100644 --- a/ssg/xml.py +++ b/ssg/xml.py @@ -21,11 +21,7 @@ cpe_language_namespace, ) - -try: - from xml.etree import cElementTree as ElementTree -except ImportError: - from xml.etree import ElementTree as ElementTree +from xml.etree import ElementTree as ElementTree def oval_generated_header(product_name, schema_version, ssg_version): @@ -102,8 +98,7 @@ def get_namespaces_from(file): # Probably an old version of Python # Doesn't matter, as this is non-essential. pass - finally: - return result + return result def open_xml(filename): diff --git a/ssg/yaml.py b/ssg/yaml.py index f7f5582a60a..c2e576a733b 100644 --- a/ssg/yaml.py +++ b/ssg/yaml.py @@ -19,17 +19,17 @@ try: from yaml import CSafeLoader as yaml_SafeLoader except ImportError: - from yaml import SafeLoader as yaml_SafeLoader + from yaml import SafeLoader as yaml_SafeLoader # type: ignore[assignment] try: from yaml import CLoader as yaml_Loader except ImportError: - from yaml import Loader as yaml_Loader + from yaml import Loader as yaml_Loader # type: ignore[assignment] try: from yaml import CDumper as yaml_Dumper except ImportError: - from yaml import Dumper as yaml_Dumper + from yaml import Dumper as yaml_Dumper # type: ignore[assignment] def _bool_constructor(self, node): """ diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index ac0b7beb9d6..f560b042ff2 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -109,6 +109,7 @@ mypy_test("utils/import_disa_stig.py" "skip") mypy_test("tests/cces-removed.py" "normal") mypy_test("utils/build_control_from_reference.py" "skip") mypy_test("tests/rule_removal.py" "skip") +mypy_test("ssg" "skip") if(Python_VERSION_MAJOR GREATER 2 AND Python_VERSION_MINOR GREATER 7 AND PY_TRESTLE AND PY_LXML) mypy_test("utils/oscal/" "skip") diff --git a/utils/template_renderer.py b/utils/template_renderer.py index d4429c9ad94..533f94ccc2b 100644 --- a/utils/template_renderer.py +++ b/utils/template_renderer.py @@ -98,8 +98,8 @@ def get_result(self): abs_basedir = os.path.join(lookup_dirs[0], template_basedir) lookup_dirs.append(abs_basedir) - ssg.jinja._get_jinja_environment(dict()) - ssg.jinja._get_jinja_environment.env.loader = FlexibleLoader(lookup_dirs) + env = ssg.jinja._get_jinja_environment({}) + env.loader = FlexibleLoader(lookup_dirs) return ssg.jinja.process_file(html_jinja_template, subst_dict) def output_results(self, args):