diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0c4bf5399..4413d7227 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -31,6 +31,7 @@ jobs: - run: invoke build.install-package - run: invoke test.format + - run: invoke test.type-check - run: invoke integration.clean - run: invoke integration.version - run: invoke integration.initialize diff --git a/policy_sentry/__init__.py b/policy_sentry/__init__.py index 911ffc7cf..139a780ad 100644 --- a/policy_sentry/__init__.py +++ b/policy_sentry/__init__.py @@ -1,4 +1,6 @@ # pylint: disable=missing-module-docstring +from __future__ import annotations + import logging from logging import NullHandler @@ -6,7 +8,11 @@ logging.getLogger(__name__).addHandler(NullHandler()) -def set_stream_logger(name='policy_sentry', level=logging.DEBUG, format_string=None): +def set_stream_logger( + name: str = "policy_sentry", + level: int = logging.DEBUG, + format_string: str | None = None, +) -> None: """ Add a stream handler for the given name and level to the logging module. By default, this logs all policy_sentry messages to ``stdout``. diff --git a/policy_sentry/shared/iam_data.py b/policy_sentry/shared/iam_data.py index 7a9a7a3cb..ad9805333 100644 --- a/policy_sentry/shared/iam_data.py +++ b/policy_sentry/shared/iam_data.py @@ -22,11 +22,13 @@ @functools.lru_cache(maxsize=1) def get_iam_definition_schema_version() -> str: - return iam_definition.get(POLICY_SENTRY_SCHEMA_VERSION_NAME, POLICY_SENTRY_SCHEMA_VERSION_V1) + return iam_definition.get( + POLICY_SENTRY_SCHEMA_VERSION_NAME, POLICY_SENTRY_SCHEMA_VERSION_V1 + ) @functools.lru_cache(maxsize=1024) -def get_service_prefix_data(service_prefix: str) -> dict[str, Any] | None: +def get_service_prefix_data(service_prefix: str) -> dict[str, Any]: """ Given an AWS service prefix, return a large dictionary of IAM privilege data for processing and analysis. @@ -41,5 +43,5 @@ def get_service_prefix_data(service_prefix: str) -> dict[str, Any] | None: return result # pylint: disable=bare-except, inconsistent-return-statements except: - logger.debug("Service prefix not %s found.", service_prefix) - return None + logger.info(f"Service prefix not {service_prefix} found.") + return {} diff --git a/policy_sentry/util/access_levels.py b/policy_sentry/util/access_levels.py index f178e8be8..ba2fca427 100644 --- a/policy_sentry/util/access_levels.py +++ b/policy_sentry/util/access_levels.py @@ -1,13 +1,19 @@ """ Util methods for handling operations relating to access levels. All of these access_levels methods are specific to policy sentry internals.""" +from __future__ import annotations + import sys import logging logger = logging.getLogger(__name__) -def override_access_level(service_override_config, action_name, provided_access_level): +def override_access_level( + service_override_config: dict[str, list[str]], + action_name: str, + provided_access_level: str, +) -> str | None: """ Given the service-specific override config, determine whether or not the override config tells us to override the access level in the documentation. @@ -17,45 +23,35 @@ def override_access_level(service_override_config, action_name, provided_access_ action_name: The name of the action provided_access_level: Read, Write, List, Tagging, or 'Permissions management'. """ - real_access_level = [] # This will hold the real access level in index 0 + real_access_level = None # This will hold the real access level try: - for i in range(len(service_override_config.keys())): - keys = list(service_override_config.keys()) - actions_list = service_override_config[keys[i]] + for access_level, actions_list in service_override_config.items(): # If it exists in the list, then set the real_access_level to the key (key is read, write, list, etc.) - # Once we meet this condition, break the loop so we can return the - # value - # pylint: disable=no-else-break - if str.lower(action_name) in actions_list: - real_access_level.append(keys[i]) + # Once we meet this condition, break the loop so we can return the value + if action_name.lower() in actions_list: + real_access_level = access_level break - else: - continue except AttributeError as a_e: logger.debug( - "AttributeError: %s\nService overrides config is %s\nKeys are %s", - a_e, - service_override_config, - service_override_config.keys(), + f"AttributeError: {a_e}\nService overrides config is {service_override_config}" ) # first index will contain the access level given in the override config for that action. # since we break the loop, we know it only contains one value. - if len(real_access_level) > 0: + if real_access_level: # If AWS hasn't fixed their documentation yet, then our YAML override cfg will not match their documentation. # Therefore, accept our override instead. - if real_access_level[0] != provided_access_level: - return real_access_level[0] + if real_access_level != provided_access_level: + return real_access_level # Otherwise, they have fixed their documentation because our override file matches their documentation. # Therefore, return false because we don't need to override - else: - return False - else: - return False + return None -def transform_access_level_text(access_level): + +def transform_access_level_text(access_level: str) -> str: """This takes the Click choices for access levels, like permissions-management, and - returns the text format matching that access level, but in the format that the database expects""" + returns the text format matching that access level, but in the format that the database expects + """ if access_level == "read": level = "Read" elif access_level == "write": @@ -73,8 +69,11 @@ def transform_access_level_text(access_level): def determine_access_level_override( - service, action_name, provided_access_level, service_override_config -): + service: str, + action_name: str, + provided_access_level: str, + service_override_config: dict[str, list[str]], +) -> str | None: """ Arguments: service: service, like iam @@ -87,32 +86,33 @@ def determine_access_level_override( # and then run the rest of this for if else statement eval later. # Using str.lower to make sure we don't get failing cases if there are # minor capitalization differences - if str.lower(provided_access_level) == str.lower("Read"): + provided_access_level_lower = provided_access_level.lower() + if provided_access_level_lower == "read": override_decision = override_access_level( - service_override_config, str.lower(action_name), "Read" + service_override_config, action_name.lower(), "Read" ) - elif str.lower(provided_access_level) == str.lower("Write"): + elif provided_access_level_lower == "write": override_decision = override_access_level( - service_override_config, str.lower(action_name), "Write" + service_override_config, action_name.lower(), "Write" ) - elif str.lower(provided_access_level) == str.lower("List"): + elif provided_access_level_lower == "list": override_decision = override_access_level( - service_override_config, str.lower(action_name), "List" + service_override_config, action_name.lower(), "List" ) - elif str.lower(provided_access_level) == str.lower("Permissions management"): + elif provided_access_level_lower == "permissions management": override_decision = override_access_level( - service_override_config, str.lower(action_name), "Permissions management" + service_override_config, action_name.lower(), "Permissions management" ) - elif str.lower(provided_access_level) == str.lower("Tagging"): + elif provided_access_level_lower == "tagging": override_decision = override_access_level( - service_override_config, str.lower(action_name), "Tagging" + service_override_config, action_name.lower(), "Tagging" ) else: logger.debug( "Unknown error - determine_override_status() can't determine the access level of %s:%s during " "the scraping process. The provided access level was %s. Exiting...", service, - str.lower(action_name), + action_name.lower(), provided_access_level, ) sys.exit() diff --git a/policy_sentry/util/actions.py b/policy_sentry/util/actions.py index 7bf74a246..e7281abfb 100644 --- a/policy_sentry/util/actions.py +++ b/policy_sentry/util/actions.py @@ -1,44 +1,41 @@ """ Text operations specific to IAM actions """ +from __future__ import annotations -def get_service_from_action(action): +def get_service_from_action(action: str) -> str: """ Returns the service name from a service:action combination :param action: ec2:DescribeInstance :return: ec2 """ - service, action_name = action.split(":") # pylint: disable=unused-variable - return str.lower(service) + service = action.split(":")[0] + return service.lower() -def get_action_name_from_action(action): +def get_action_name_from_action(action: str) -> str: """ Returns the lowercase action name from a service:action combination :param action: ec2:DescribeInstance :return: describeinstance """ - service, action_name = action.split(":") # pylint: disable=unused-variable - return str.lower(action_name) + action_name = action.split(":")[1] + return action_name.lower() -def get_lowercase_action_list(action_list): +def get_lowercase_action_list(action_list: list[str]) -> list[str]: """ Given a list of actions, return the list but in lowercase format """ - new_action_list = [] - for action in action_list: - new_action_list.append(str.lower(action)) - return new_action_list + return [action.lower() for action in action_list] -def get_full_action_name(service, action_name): +def get_full_action_name(service: str, action_name: str) -> str: """ Gets the proper formatting for an action - the service, plus colon, plus action name. :param service: service name, like s3 :param action_name: action name, like createbucket :return: the resulting string """ - action = service + ":" + action_name - return action + return f"{service}:{action_name}" diff --git a/policy_sentry/util/conditions.py b/policy_sentry/util/conditions.py index dd45df619..d01ffc229 100644 --- a/policy_sentry/util/conditions.py +++ b/policy_sentry/util/conditions.py @@ -28,23 +28,22 @@ def translate_condition_key_data_types(condition_str: str) -> str: raise Exception(f"Unknown data format: {condition_lowercase}") -def get_service_from_condition_key(condition_key): +def get_service_from_condition_key(condition_key: str) -> str: """Given a condition key, return the service prefix""" - elements = condition_key.split(":", 2) - return elements[0] + return condition_key.split(":", 2)[0] -def get_comma_separated_condition_keys(condition_keys): +def get_comma_separated_condition_keys(condition_keys: str) -> str: """ :param condition_keys: String containing multiple condition keys, separated by double spaces :return: result: String containing multiple condition keys, comma-separated """ - result = condition_keys.replace(" ", ",") # replace the double spaces with a comma - return result + # replace the double spaces with a comma + return condition_keys.replace(" ", ",") -def is_condition_key_match(document_key, some_str): +def is_condition_key_match(document_key: str, some_str: str) -> bool: """ Given a documented condition key and one from a policy, determine if they match Examples: - s3:prefix and s3:prefix obviously match diff --git a/policy_sentry/util/file.py b/policy_sentry/util/file.py index 00dfcd4fd..a3ecc66d6 100755 --- a/policy_sentry/util/file.py +++ b/policy_sentry/util/file.py @@ -1,13 +1,18 @@ """ Functions that relate to manipulating files, loading files, and managing filepaths. """ +from __future__ import annotations + import logging +from pathlib import Path +from typing import Any, cast + import yaml logger = logging.getLogger(__name__) -def read_yaml_file(filename): +def read_yaml_file(filename: str | Path) -> dict[str, Any]: """ Reads a YAML file, safe loads, and returns the dictionary @@ -16,7 +21,7 @@ def read_yaml_file(filename): """ with open(filename, "r") as yaml_file: try: - cfg = yaml.safe_load(yaml_file) + cfg = cast("dict[str, Any]", yaml.safe_load(yaml_file)) except yaml.YAMLError as exc: logger.critical(exc) return cfg diff --git a/policy_sentry/util/text.py b/policy_sentry/util/text.py index 9b726d4e0..1f3a5d95a 100644 --- a/policy_sentry/util/text.py +++ b/policy_sentry/util/text.py @@ -3,7 +3,7 @@ """ -def capitalize_first_character(some_string): +def capitalize_first_character(some_string: str) -> str: """ Description: Capitalizes the first character of a string :param some_string: @@ -12,9 +12,8 @@ def capitalize_first_character(some_string): return " ".join("".join([w[0].upper(), w[1:].lower()]) for w in some_string.split()) -def strip_special_characters(some_string): +def strip_special_characters(some_string: str) -> str: """Remove all special characters, punctuation, and spaces from a string""" # Input: "Special $#! characters spaces 888323" # Output: 'Specialcharactersspaces888323' - result = ''.join(e for e in some_string if e.isalnum()) - return result + return "".join(e for e in some_string if e.isalnum()) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..7ca67021a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,13 @@ +[tool.mypy] +files = "policy_sentry" +strict = true + +exclude = [ + '^policy_sentry/analysis', + '^policy_sentry/bin', + '^policy_sentry/command', + '^policy_sentry/querying', + '^policy_sentry/shared', + '^policy_sentry/util/(arns|policy_files)', + '^policy_sentry/writing', +] diff --git a/requirements-dev.txt b/requirements-dev.txt index 27c9c0767..f2d64933f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -9,6 +9,9 @@ safety==2.3.5 bandit==1.7.5 # Formatting black==22.12.0 +# Type hints +mypy==1.4.1 +types-pyyaml==6.0.12.11 # Other? Maybe this is from the docs? Not sure. # appdirs==1.4.4 # astroid==2.5.6 diff --git a/tasks.py b/tasks.py index e1f6f261a..09c757ac7 100755 --- a/tasks.py +++ b/tasks.py @@ -242,6 +242,20 @@ def run_linter(c): sys.exit(1) +# TEST - type check +@task +def run_mypy(c): + """Type checking with `mypy`""" + try: + c.run('mypy policy_sentry/') + except UnexpectedExit as u_e: + logger.critical(f"FAIL! UnexpectedExit: {u_e}") + sys.exit(1) + except Failure as f_e: + logger.critical(f"FAIL: Failure: {f_e}") + sys.exit(1) + + # UNIT TESTING @task def run_pytest(c): @@ -282,6 +296,7 @@ def build_docker(c): # test.add_task(run_full_test_suite, 'all') test.add_task(format, 'format') test.add_task(run_linter, 'lint') +test.add_task(run_mypy, 'type-check') test.add_task(security_scan, 'security') build.add_task(build_package, 'build-package')