Skip to content

Commit

Permalink
add initial setup for type checking
Browse files Browse the repository at this point in the history
  • Loading branch information
gruebel committed Sep 2, 2023
1 parent 469f480 commit a4c973e
Show file tree
Hide file tree
Showing 11 changed files with 110 additions and 70 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ jobs:
- run: invoke build.install-package
- run: invoke test.format
- run: invoke test.type-check
- run: invoke integration.clean
- run: invoke integration.version
- run: invoke integration.initialize
Expand Down
8 changes: 7 additions & 1 deletion policy_sentry/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
# pylint: disable=missing-module-docstring
from __future__ import annotations

import logging
from logging import NullHandler

# Set default handler when policy_sentry is used as library to avoid "No handler found" warnings.
logging.getLogger(__name__).addHandler(NullHandler())


def set_stream_logger(name='policy_sentry', level=logging.DEBUG, format_string=None):
def set_stream_logger(
name: str = "policy_sentry",
level: int = logging.DEBUG,
format_string: str | None = None,
) -> None:
"""
Add a stream handler for the given name and level to the logging module.
By default, this logs all policy_sentry messages to ``stdout``.
Expand Down
10 changes: 6 additions & 4 deletions policy_sentry/shared/iam_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@

@functools.lru_cache(maxsize=1)
def get_iam_definition_schema_version() -> str:
return iam_definition.get(POLICY_SENTRY_SCHEMA_VERSION_NAME, POLICY_SENTRY_SCHEMA_VERSION_V1)
return iam_definition.get(
POLICY_SENTRY_SCHEMA_VERSION_NAME, POLICY_SENTRY_SCHEMA_VERSION_V1
)


@functools.lru_cache(maxsize=1024)
def get_service_prefix_data(service_prefix: str) -> dict[str, Any] | None:
def get_service_prefix_data(service_prefix: str) -> dict[str, Any]:
"""
Given an AWS service prefix, return a large dictionary of IAM privilege data for processing and analysis.
Expand All @@ -41,5 +43,5 @@ def get_service_prefix_data(service_prefix: str) -> dict[str, Any] | None:
return result
# pylint: disable=bare-except, inconsistent-return-statements
except:
logger.debug("Service prefix not %s found.", service_prefix)
return None
logger.info(f"Service prefix not {service_prefix} found.")
return {}
76 changes: 38 additions & 38 deletions policy_sentry/util/access_levels.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
"""
Util methods for handling operations relating to access levels.
All of these access_levels methods are specific to policy sentry internals."""
from __future__ import annotations

import sys
import logging

logger = logging.getLogger(__name__)


def override_access_level(service_override_config, action_name, provided_access_level):
def override_access_level(
service_override_config: dict[str, list[str]],
action_name: str,
provided_access_level: str,
) -> str | None:
"""
Given the service-specific override config, determine whether or not the
override config tells us to override the access level in the documentation.
Expand All @@ -17,45 +23,35 @@ def override_access_level(service_override_config, action_name, provided_access_
action_name: The name of the action
provided_access_level: Read, Write, List, Tagging, or 'Permissions management'.
"""
real_access_level = [] # This will hold the real access level in index 0
real_access_level = None # This will hold the real access level
try:
for i in range(len(service_override_config.keys())):
keys = list(service_override_config.keys())
actions_list = service_override_config[keys[i]]
for access_level, actions_list in service_override_config.items():
# If it exists in the list, then set the real_access_level to the key (key is read, write, list, etc.)
# Once we meet this condition, break the loop so we can return the
# value
# pylint: disable=no-else-break
if str.lower(action_name) in actions_list:
real_access_level.append(keys[i])
# Once we meet this condition, break the loop so we can return the value
if action_name.lower() in actions_list:
real_access_level = access_level
break
else:
continue
except AttributeError as a_e:
logger.debug(
"AttributeError: %s\nService overrides config is %s\nKeys are %s",
a_e,
service_override_config,
service_override_config.keys(),
f"AttributeError: {a_e}\nService overrides config is {service_override_config}"
)
# first index will contain the access level given in the override config for that action.
# since we break the loop, we know it only contains one value.
if len(real_access_level) > 0:
if real_access_level:
# If AWS hasn't fixed their documentation yet, then our YAML override cfg will not match their documentation.
# Therefore, accept our override instead.
if real_access_level[0] != provided_access_level:
return real_access_level[0]
if real_access_level != provided_access_level:
return real_access_level
# Otherwise, they have fixed their documentation because our override file matches their documentation.
# Therefore, return false because we don't need to override
else:
return False
else:
return False

return None

def transform_access_level_text(access_level):

def transform_access_level_text(access_level: str) -> str:
"""This takes the Click choices for access levels, like permissions-management, and
returns the text format matching that access level, but in the format that the database expects"""
returns the text format matching that access level, but in the format that the database expects
"""
if access_level == "read":
level = "Read"
elif access_level == "write":
Expand All @@ -73,8 +69,11 @@ def transform_access_level_text(access_level):


def determine_access_level_override(
service, action_name, provided_access_level, service_override_config
):
service: str,
action_name: str,
provided_access_level: str,
service_override_config: dict[str, list[str]],
) -> str | None:
"""
Arguments:
service: service, like iam
Expand All @@ -87,32 +86,33 @@ def determine_access_level_override(
# and then run the rest of this for if else statement eval later.
# Using str.lower to make sure we don't get failing cases if there are
# minor capitalization differences
if str.lower(provided_access_level) == str.lower("Read"):
provided_access_level_lower = provided_access_level.lower()
if provided_access_level_lower == "read":
override_decision = override_access_level(
service_override_config, str.lower(action_name), "Read"
service_override_config, action_name.lower(), "Read"
)
elif str.lower(provided_access_level) == str.lower("Write"):
elif provided_access_level_lower == "write":
override_decision = override_access_level(
service_override_config, str.lower(action_name), "Write"
service_override_config, action_name.lower(), "Write"
)
elif str.lower(provided_access_level) == str.lower("List"):
elif provided_access_level_lower == "list":
override_decision = override_access_level(
service_override_config, str.lower(action_name), "List"
service_override_config, action_name.lower(), "List"
)
elif str.lower(provided_access_level) == str.lower("Permissions management"):
elif provided_access_level_lower == "permissions management":
override_decision = override_access_level(
service_override_config, str.lower(action_name), "Permissions management"
service_override_config, action_name.lower(), "Permissions management"
)
elif str.lower(provided_access_level) == str.lower("Tagging"):
elif provided_access_level_lower == "tagging":
override_decision = override_access_level(
service_override_config, str.lower(action_name), "Tagging"
service_override_config, action_name.lower(), "Tagging"
)
else:
logger.debug(
"Unknown error - determine_override_status() can't determine the access level of %s:%s during "
"the scraping process. The provided access level was %s. Exiting...",
service,
str.lower(action_name),
action_name.lower(),
provided_access_level,
)
sys.exit()
Expand Down
25 changes: 11 additions & 14 deletions policy_sentry/util/actions.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,41 @@
"""
Text operations specific to IAM actions
"""
from __future__ import annotations


def get_service_from_action(action):
def get_service_from_action(action: str) -> str:
"""
Returns the service name from a service:action combination
:param action: ec2:DescribeInstance
:return: ec2
"""
service, action_name = action.split(":") # pylint: disable=unused-variable
return str.lower(service)
service = action.split(":")[0]
return service.lower()


def get_action_name_from_action(action):
def get_action_name_from_action(action: str) -> str:
"""
Returns the lowercase action name from a service:action combination
:param action: ec2:DescribeInstance
:return: describeinstance
"""
service, action_name = action.split(":") # pylint: disable=unused-variable
return str.lower(action_name)
action_name = action.split(":")[1]
return action_name.lower()


def get_lowercase_action_list(action_list):
def get_lowercase_action_list(action_list: list[str]) -> list[str]:
"""
Given a list of actions, return the list but in lowercase format
"""
new_action_list = []
for action in action_list:
new_action_list.append(str.lower(action))
return new_action_list
return [action.lower() for action in action_list]


def get_full_action_name(service, action_name):
def get_full_action_name(service: str, action_name: str) -> str:
"""
Gets the proper formatting for an action - the service, plus colon, plus action name.
:param service: service name, like s3
:param action_name: action name, like createbucket
:return: the resulting string
"""
action = service + ":" + action_name
return action
return f"{service}:{action_name}"
13 changes: 6 additions & 7 deletions policy_sentry/util/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,22 @@ def translate_condition_key_data_types(condition_str: str) -> str:
raise Exception(f"Unknown data format: {condition_lowercase}")


def get_service_from_condition_key(condition_key):
def get_service_from_condition_key(condition_key: str) -> str:
"""Given a condition key, return the service prefix"""
elements = condition_key.split(":", 2)
return elements[0]
return condition_key.split(":", 2)[0]


def get_comma_separated_condition_keys(condition_keys):
def get_comma_separated_condition_keys(condition_keys: str) -> str:
"""
:param condition_keys: String containing multiple condition keys, separated by double spaces
:return: result: String containing multiple condition keys, comma-separated
"""

result = condition_keys.replace(" ", ",") # replace the double spaces with a comma
return result
# replace the double spaces with a comma
return condition_keys.replace(" ", ",")


def is_condition_key_match(document_key, some_str):
def is_condition_key_match(document_key: str, some_str: str) -> bool:
""" Given a documented condition key and one from a policy, determine if they match
Examples:
- s3:prefix and s3:prefix obviously match
Expand Down
9 changes: 7 additions & 2 deletions policy_sentry/util/file.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
"""
Functions that relate to manipulating files, loading files, and managing filepaths.
"""
from __future__ import annotations

import logging
from pathlib import Path
from typing import Any, cast

import yaml

logger = logging.getLogger(__name__)


def read_yaml_file(filename):
def read_yaml_file(filename: str | Path) -> dict[str, Any]:
"""
Reads a YAML file, safe loads, and returns the dictionary
Expand All @@ -16,7 +21,7 @@ def read_yaml_file(filename):
"""
with open(filename, "r") as yaml_file:
try:
cfg = yaml.safe_load(yaml_file)
cfg = cast("dict[str, Any]", yaml.safe_load(yaml_file))
except yaml.YAMLError as exc:
logger.critical(exc)
return cfg
7 changes: 3 additions & 4 deletions policy_sentry/util/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""


def capitalize_first_character(some_string):
def capitalize_first_character(some_string: str) -> str:
"""
Description: Capitalizes the first character of a string
:param some_string:
Expand All @@ -12,9 +12,8 @@ def capitalize_first_character(some_string):
return " ".join("".join([w[0].upper(), w[1:].lower()]) for w in some_string.split())


def strip_special_characters(some_string):
def strip_special_characters(some_string: str) -> str:
"""Remove all special characters, punctuation, and spaces from a string"""
# Input: "Special $#! characters spaces 888323"
# Output: 'Specialcharactersspaces888323'
result = ''.join(e for e in some_string if e.isalnum())
return result
return "".join(e for e in some_string if e.isalnum())
13 changes: 13 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[tool.mypy]
files = "policy_sentry"
strict = true

exclude = [
'^policy_sentry/analysis',
'^policy_sentry/bin',
'^policy_sentry/command',
'^policy_sentry/querying',
'^policy_sentry/shared',
'^policy_sentry/util/(arns|policy_files)',
'^policy_sentry/writing',
]
3 changes: 3 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ safety==2.3.5
bandit==1.7.5
# Formatting
black==22.12.0
# Type hints
mypy==1.4.1
types-pyyaml==6.0.12.11
# Other? Maybe this is from the docs? Not sure.
# appdirs==1.4.4
# astroid==2.5.6
Expand Down
15 changes: 15 additions & 0 deletions tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,20 @@ def run_linter(c):
sys.exit(1)


# TEST - type check
@task
def run_mypy(c):
"""Type checking with `mypy`"""
try:
c.run('mypy policy_sentry/')
except UnexpectedExit as u_e:
logger.critical(f"FAIL! UnexpectedExit: {u_e}")
sys.exit(1)
except Failure as f_e:
logger.critical(f"FAIL: Failure: {f_e}")
sys.exit(1)


# UNIT TESTING
@task
def run_pytest(c):
Expand Down Expand Up @@ -282,6 +296,7 @@ def build_docker(c):
# test.add_task(run_full_test_suite, 'all')
test.add_task(format, 'format')
test.add_task(run_linter, 'lint')
test.add_task(run_mypy, 'type-check')
test.add_task(security_scan, 'security')

build.add_task(build_package, 'build-package')
Expand Down

0 comments on commit a4c973e

Please sign in to comment.