From 15b989077aa9c5185faa7891a71c32e93e8ee812 Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Tue, 2 Jan 2024 21:12:49 +1100 Subject: [PATCH 01/15] [typing] Add mypy.ini and py.typed marker --- mypy.ini | 51 +++++++++++++++++++++++++++++++++++++++++++++ parsedmarc/py.typed | 0 2 files changed, 51 insertions(+) create mode 100644 mypy.ini create mode 100644 parsedmarc/py.typed diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..03056b65 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,51 @@ +[mypy] + +[mypy-kafka.*] +# https://github.com/dpkp/kafka-python/issues/2245 +ignore_missing_imports = True + +[mypy-imapclient.*] +# https://github.com/mjs/imapclient/issues/463 +ignore_missing_imports = True + +[mypy-elasticsearch_dsl.*] +# https://github.com/elastic/elasticsearch-dsl-py/issues/1533 +ignore_missing_imports = True + +[mypy-mailsuite.*] +# https://github.com/seanthegeek/mailsuite/issues/9 +ignore_missing_imports = True + +[mypy-msgraph.core] +# https://github.com/microsoftgraph/msgraph-sdk-python-core/issues/446 +ignore_missing_imports = True + +[mypy-mailparser.*] +# https://github.com/SpamScope/mail-parser +ignore_missing_imports = True + +[mypy-expiringdict] +# https://github.com/mailgun/expiringdict +# uses inline typing which is not exposed +ignore_missing_imports = True + +[mypy-publicsuffixlist.*] +# https://github.com/ko-zu/psl +# supports python 2.6 so no inline type hints +ignore_missing_imports = True + +[mypy-google_auth_oauthlib.*] +# https://github.com/cffnpwr/google-auth-oauthlib-stubs/issues/1 +ignore_missing_imports = True + +# pip install the following: +# lxml-stubs +# types-python-dateutil +# types-requests +# types-tqdm +# types-xmltodict +# google-api-python-client-stubs +# google-auth-stubs +# types-boto3 + +# importlib_resources https://github.com/python/importlib_resources/blob/main/importlib_resources/py.typed:23 diff --git a/parsedmarc/py.typed b/parsedmarc/py.typed new file mode 100644 index 00000000..e69de29b From 838277ddb0fb216c58ce679f105b9b829664040b Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Wed, 3 Jan 2024 01:06:02 +1100 Subject: [PATCH 02/15] [typing] Initial pass at typing --- parsedmarc/__init__.py | 540 ++++++++++++++------------ parsedmarc/loganalytics.py | 2 +- parsedmarc/mail/gmail.py | 45 ++- parsedmarc/mail/graph.py | 42 +- parsedmarc/mail/mailbox_connection.py | 18 +- parsedmarc/utils.py | 152 ++++---- pyproject.toml | 3 + 7 files changed, 452 insertions(+), 350 deletions(-) diff --git a/parsedmarc/__init__.py b/parsedmarc/__init__.py index 8584def9..af26423d 100644 --- a/parsedmarc/__init__.py +++ b/parsedmarc/__init__.py @@ -19,7 +19,7 @@ from csv import DictWriter from datetime import datetime from io import BytesIO, StringIO -from typing import Callable +from typing import List, Dict, Any, Optional, Union, Callable, BinaryIO, cast import mailparser import xmltodict @@ -66,25 +66,31 @@ class InvalidForensicReport(InvalidDMARCReport): """Raised when an invalid DMARC forensic report is encountered""" -def _parse_report_record(record, ip_db_path=None, offline=False, - nameservers=None, dns_timeout=2.0, parallel=False): +def _parse_report_record( + record: OrderedDict, + ip_db_path: Optional[str] = None, + offline: bool = False, + nameservers: Optional[List[str]] = None, + dns_timeout: float = 2.0, + parallel: bool = False +) -> OrderedDict: """ Converts a record from a DMARC aggregate report into a more consistent format Args: - record (OrderedDict): The record to convert - ip_db_path (str): Path to a MMDB file from 
MaxMind or DBIP - offline (bool): Do not query online for geolocation or DNS - nameservers (list): A list of one or more nameservers to use + record: The record to convert + ip_db_path: Path to a MMDB file from MaxMind or DBIP + offline: Do not query online for geolocation or DNS + nameservers: A list of one or more nameservers to use (Cloudflare's public DNS resolvers by default) - dns_timeout (float): Sets the DNS timeout in seconds + dns_timeout: Sets the DNS timeout in seconds Returns: - OrderedDict: The converted record + The converted record """ record = record.copy() - new_record = OrderedDict() + new_record: OrderedDict[str, Any] = OrderedDict() new_record_source = get_ip_address_info(record["row"]["source_ip"], cache=IP_ADDRESS_CACHE, ip_db_path=ip_db_path, @@ -102,7 +108,7 @@ def _parse_report_record(record, ip_db_path=None, offline=False, ]) if "disposition" in policy_evaluated: new_policy_evaluated["disposition"] = policy_evaluated["disposition"] - if new_policy_evaluated["disposition"].strip().lower() == "pass": + if cast(str, new_policy_evaluated["disposition"]).strip().lower() == "pass": new_policy_evaluated["disposition"] = "none" if "dkim" in policy_evaluated: new_policy_evaluated["dkim"] = policy_evaluated["dkim"] @@ -204,23 +210,29 @@ def _parse_report_record(record, ip_db_path=None, offline=False, return new_record -def parse_aggregate_report_xml(xml, ip_db_path=None, offline=False, - nameservers=None, timeout=2.0, - parallel=False, keep_alive=None): +def parse_aggregate_report_xml( + xml: str, + ip_db_path: Optional[str] = None, + offline: bool = False, + nameservers: Optional[List[str]] = None, + timeout: float = 2.0, + parallel: bool = False, + keep_alive: Optional[Callable] = None +) -> OrderedDict[str, Any]: """Parses a DMARC XML report string and returns a consistent OrderedDict Args: - xml (str): A string of DMARC aggregate report XML - ip_db_path (str): Path to a MMDB file from MaxMind or DBIP - offline (bool): Do not query online for geolocation or DNS - nameservers (list): A list of one or more nameservers to use + xml: A string of DMARC aggregate report XML + ip_db_path: Path to a MMDB file from MaxMind or DBIP + offline: Do not query online for geolocation or DNS + nameservers: A list of one or more nameservers to use (Cloudflare's public DNS resolvers by default) - timeout (float): Sets the DNS timeout in seconds - parallel (bool): Parallel processing - keep_alive (callable): Keep alive function + timeout: Sets the DNS timeout in seconds + parallel: Parallel processing + keep_alive: Keep alive function Returns: - OrderedDict: The parsed aggregate DMARC report + The parsed aggregate DMARC report """ errors = [] # Parse XML and recover from errors @@ -249,8 +261,8 @@ def parse_aggregate_report_xml(xml, ip_db_path=None, offline=False, schema = "draft" if "version" in report: schema = report["version"] - new_report = OrderedDict([("xml_schema", schema)]) - new_report_metadata = OrderedDict() + new_report: OrderedDict[str, Any] = OrderedDict([("xml_schema", schema)]) + new_report_metadata: OrderedDict[str, Any] = OrderedDict() if report_metadata["org_name"] is None: if report_metadata["email"] is not None: report_metadata["org_name"] = report_metadata[ @@ -368,7 +380,7 @@ def parse_aggregate_report_xml(xml, ip_db_path=None, offline=False, "Unexpected error: {0}".format(error.__str__())) -def extract_xml(input_): +def extract_xml(input_: Union[str, bytes, BinaryIO]) -> str: """ Extracts xml from a zip or gzip file at the given path, file-like object, or bytes. 
@@ -377,27 +389,31 @@ def extract_xml(input_): input_: A path to a file, a file like object, or bytes Returns: - str: The extracted XML + The extracted XML """ + file_object: BinaryIO try: - if type(input_) is str: + if isinstance(input_, str): file_object = open(input_, "rb") - elif type(input_) is bytes: + elif isinstance(input_, bytes): file_object = BytesIO(input_) else: file_object = input_ header = file_object.read(6) file_object.seek(0) + if header.startswith(MAGIC_ZIP): _zip = zipfile.ZipFile(file_object) xml = _zip.open(_zip.namelist()[0]).read().decode(errors='ignore') + elif header.startswith(MAGIC_GZIP): - xml = zlib.decompress(file_object.getvalue(), - zlib.MAX_WBITS | 16).decode(errors='ignore') + xml = zlib.decompress(file_object.read(), zlib.MAX_WBITS | 16).decode(errors='ignore') + elif header.startswith(MAGIC_XML): xml = file_object.read().decode(errors='ignore') + else: file_object.close() raise InvalidAggregateReport("Not a valid zip, gzip, or xml file") @@ -418,39 +434,45 @@ def extract_xml(input_): return xml -def parse_aggregate_report_file(_input, offline=False, ip_db_path=None, - nameservers=None, - dns_timeout=2.0, - parallel=False, - keep_alive=None): +def parse_aggregate_report_file( + _input: Union[bytes, str, BinaryIO], + offline: bool =False, + ip_db_path: Optional[str] = None, + nameservers: Optional[List[str]] = None, + dns_timeout: float = 2.0, + parallel: bool = False, + keep_alive: Optional[Callable] = None +) -> OrderedDict[str, Any]: """Parses a file at the given path, a file-like object. or bytes as an aggregate DMARC report Args: _input: A path to a file, a file like object, or bytes - offline (bool): Do not query online for geolocation or DNS - ip_db_path (str): Path to a MMDB file from MaxMind or DBIP - nameservers (list): A list of one or more nameservers to use + offline: Do not query online for geolocation or DNS + ip_db_path: Path to a MMDB file from MaxMind or DBIP + nameservers: A list of one or more nameservers to use (Cloudflare's public DNS resolvers by default) - dns_timeout (float): Sets the DNS timeout in seconds - parallel (bool): Parallel processing - keep_alive (callable): Keep alive function + dns_timeout: Sets the DNS timeout in seconds + parallel: Parallel processing + keep_alive: Keep alive function Returns: - OrderedDict: The parsed DMARC aggregate report + The parsed DMARC aggregate report """ xml = extract_xml(_input) - return parse_aggregate_report_xml(xml, - ip_db_path=ip_db_path, - offline=offline, - nameservers=nameservers, - timeout=dns_timeout, - parallel=parallel, - keep_alive=keep_alive) + return parse_aggregate_report_xml( + xml, + ip_db_path=ip_db_path, + offline=offline, + nameservers=nameservers, + timeout=dns_timeout, + parallel=parallel, + keep_alive=keep_alive, + ) -def parsed_aggregate_reports_to_csv_rows(reports): +def parsed_aggregate_reports_to_csv_rows(reports: Union[OrderedDict, List[OrderedDict]]) -> List[Dict[str, Union[str, int, bool]]]: """ Converts one or more parsed aggregate reports to list of dicts in flat CSV format @@ -459,7 +481,7 @@ def parsed_aggregate_reports_to_csv_rows(reports): reports: A parsed aggregate report or list of parsed aggregate reports Returns: - list: Parsed aggregate report data as a list of dicts in flat CSV + Parsed aggregate report data as a list of dicts in flat CSV format """ @@ -553,7 +575,7 @@ def to_str(obj): return rows -def parsed_aggregate_reports_to_csv(reports): +def parsed_aggregate_reports_to_csv(reports: Union[OrderedDict, List[OrderedDict]]) -> str: """ 
Converts one or more parsed aggregate reports to flat CSV format, including headers @@ -562,7 +584,7 @@ def parsed_aggregate_reports_to_csv(reports): reports: A parsed aggregate report or list of parsed aggregate reports Returns: - str: Parsed aggregate report data in flat CSV format, including headers + Parsed aggregate report data in flat CSV format, including headers """ fields = ["xml_schema", "org_name", "org_email", @@ -589,29 +611,35 @@ def parsed_aggregate_reports_to_csv(reports): return csv_file_object.getvalue() -def parse_forensic_report(feedback_report, sample, msg_date, - offline=False, ip_db_path=None, - nameservers=None, dns_timeout=2.0, - strip_attachment_payloads=False, - parallel=False): +def parse_forensic_report( + feedback_report: str, + sample: str, + msg_date: datetime, + offline: bool = False, + ip_db_path: Optional[str] = None, + nameservers: Optional[List[str]] = None, + dns_timeout: float = 2.0, + strip_attachment_payloads: bool = False, + parallel: bool = False +) -> OrderedDict: """ Converts a DMARC forensic report and sample to a ``OrderedDict`` Args: - feedback_report (str): A message's feedback report as a string - ip_db_path (str): Path to a MMDB file from MaxMind or DBIP - offline (bool): Do not query online for geolocation or DNS - sample (str): The RFC 822 headers or RFC 822 message sample - msg_date (str): The message's date header + feedback_report: A message's feedback report as a string + sample: The RFC 822 headers or RFC 822 message sample + msg_date: The message's date header + offline: Do not query online for geolocation or DNS + ip_db_path: Path to a MMDB file from MaxMind or DBIP nameservers (list): A list of one or more nameservers to use (Cloudflare's public DNS resolvers by default) - dns_timeout (float): Sets the DNS timeout in seconds - strip_attachment_payloads (bool): Remove attachment payloads from + dns_timeout: Sets the DNS timeout in seconds + strip_attachment_payloads: Remove attachment payloads from forensic report results - parallel (bool): Parallel processing + parallel: Parallel processing Returns: - OrderedDict: A parsed report and sample + A parsed report and sample """ delivery_results = ["delivered", "spam", "policy", "reject", "other"] @@ -644,10 +672,8 @@ def parse_forensic_report(feedback_report, sample, msg_date, if parsed_report["delivery_result"] not in delivery_results: parsed_report["delivery_result"] = "other" - arrival_utc = human_timestamp_to_datetime( - parsed_report["arrival_date"], to_utc=True) - arrival_utc = arrival_utc.strftime("%Y-%m-%d %H:%M:%S") - parsed_report["arrival_date_utc"] = arrival_utc + parsed_report["arrival_date_utc"] = human_timestamp_to_datetime( + parsed_report["arrival_date"], to_utc=True).strftime("%Y-%m-%d %H:%M:%S") ip_address = re.split(r'\s', parsed_report["source_ip"]).pop(0) parsed_report_source = get_ip_address_info(ip_address, @@ -711,7 +737,7 @@ def parse_forensic_report(feedback_report, sample, msg_date, "Unexpected error: {0}".format(error.__str__())) -def parsed_forensic_reports_to_csv_rows(reports): +def parsed_forensic_reports_to_csv_rows(reports: Union[OrderedDict, List[OrderedDict]]) -> List[Dict[str, Any]]: """ Converts one or more parsed forensic reports to a list of dicts in flat CSV format @@ -720,7 +746,7 @@ def parsed_forensic_reports_to_csv_rows(reports): reports: A parsed forensic report or list of parsed forensic reports Returns: - list: Parsed forensic report data as a list of dicts in flat CSV format + Parsed forensic report data as a list of dicts in flat CSV 
format """ if type(reports) is OrderedDict: reports = [reports] @@ -746,7 +772,7 @@ def parsed_forensic_reports_to_csv_rows(reports): return rows -def parsed_forensic_reports_to_csv(reports): +def parsed_forensic_reports_to_csv(reports: Union[OrderedDict, List[OrderedDict]]) -> str: """ Converts one or more parsed forensic reports to flat CSV format, including headers @@ -755,7 +781,7 @@ def parsed_forensic_reports_to_csv(reports): reports: A parsed forensic report or list of parsed forensic reports Returns: - str: Parsed forensic report data in flat CSV format, including headers + Parsed forensic report data in flat CSV format, including headers """ fields = ["feedback_type", "user_agent", "version", "original_envelope_id", "original_mail_from", "original_rcpt_to", "arrival_date", @@ -772,7 +798,7 @@ def parsed_forensic_reports_to_csv(reports): rows = parsed_forensic_reports_to_csv_rows(reports) for row in rows: - new_row = {} + new_row: Dict[str, Any] = {} for key in new_row.keys(): new_row[key] = row[key] csv_writer.writerow(new_row) @@ -780,42 +806,44 @@ def parsed_forensic_reports_to_csv(reports): return csv_file.getvalue() -def parse_report_email(input_, offline=False, ip_db_path=None, - nameservers=None, dns_timeout=2.0, - strip_attachment_payloads=False, - parallel=False, keep_alive=None): +def parse_report_email( + input_: Union[bytes, str], + offline: bool = False, + ip_db_path: Optional[str] = None, + nameservers: Optional[List[str]] = None, + dns_timeout: float = 2.0, + strip_attachment_payloads: bool = False, + parallel: bool = False, + keep_alive: Optional[Callable] = None, +) -> OrderedDict[str, Union[str, OrderedDict]]: """ Parses a DMARC report from an email Args: input_: An emailed DMARC report in RFC 822 format, as bytes or a string - ip_db_path (str): Path to a MMDB file from MaxMind or DBIP - offline (bool): Do not query online for geolocation on DNS - nameservers (list): A list of one or more nameservers to use - dns_timeout (float): Sets the DNS timeout in seconds - strip_attachment_payloads (bool): Remove attachment payloads from + offline: Do not query online for geolocation on DNS + ip_db_path: Path to a MMDB file from MaxMind or DBIP + nameservers: A list of one or more nameservers to use + dns_timeout: Sets the DNS timeout in seconds + strip_attachment_payloads: Remove attachment payloads from forensic report results - parallel (bool): Parallel processing - keep_alive (callable): keep alive function + parallel: Parallel processing + keep_alive: keep alive function Returns: - OrderedDict: - * ``report_type``: ``aggregate`` or ``forensic`` - * ``report``: The parsed report + Dictionary of `{"report_type": "aggregate" or "forensic", "report": report}` """ - result = None - try: - if is_outlook_msg(input_): + if isinstance(input_, bytes) and is_outlook_msg(input_): input_ = convert_outlook_msg(input_) - if type(input_) is bytes: + if isinstance(input_, bytes): input_ = input_.decode(encoding="utf8", errors="replace") msg = mailparser.parse_from_string(input_) msg_headers = json.loads(msg.headers_json) - date = email.utils.format_datetime(datetime.utcnow()) if "Date" in msg_headers: - date = human_timestamp_to_datetime( - msg_headers["Date"]) + date = human_timestamp_to_datetime(msg_headers["Date"]) + else: + date = datetime.utcnow() msg = email.message_from_string(input_) except Exception as e: @@ -880,9 +908,7 @@ def parse_report_email(input_, offline=False, ip_db_path=None, dns_timeout=dns_timeout, parallel=parallel, keep_alive=keep_alive) - result = 
OrderedDict([("report_type", "aggregate"), - ("report", aggregate_report)]) - return result + return OrderedDict([("report_type", "aggregate"), ("report", aggregate_report)]) except (TypeError, ValueError, binascii.Error): pass @@ -917,99 +943,112 @@ def parse_report_email(input_, offline=False, ip_db_path=None, except Exception as e: raise InvalidForensicReport(e.__str__()) - result = OrderedDict([("report_type", "forensic"), - ("report", forensic_report)]) - return result + return OrderedDict([("report_type", "forensic"), ("report", forensic_report)]) - if result is None: - error = 'Message with subject "{0}" is ' \ - 'not a valid DMARC report'.format(subject) - raise InvalidDMARCReport(error) + error = 'Message with subject "{0}" is not a valid DMARC report'.format(subject) + raise InvalidDMARCReport(error) -def parse_report_file(input_, nameservers=None, dns_timeout=2.0, - strip_attachment_payloads=False, ip_db_path=None, - offline=False, parallel=False, keep_alive=None): +def parse_report_file( + input_: Union[str, bytes, BinaryIO], + nameservers: Optional[List[str]] = None, + dns_timeout: float = 2.0, + strip_attachment_payloads: bool = False, + ip_db_path: Optional[str] = None, + offline: bool = False, + parallel: bool = False, + keep_alive: Optional[Callable] = None +) -> OrderedDict: """Parses a DMARC aggregate or forensic file at the given path, a file-like object. or bytes Args: input_: A path to a file, a file like object, or bytes - nameservers (list): A list of one or more nameservers to use + nameservers: A list of one or more nameservers to use (Cloudflare's public DNS resolvers by default) - dns_timeout (float): Sets the DNS timeout in seconds - strip_attachment_payloads (bool): Remove attachment payloads from + dns_timeout: Sets the DNS timeout in seconds + strip_attachment_payloads: Remove attachment payloads from forensic report results - ip_db_path (str): Path to a MMDB file from MaxMind or DBIP - offline (bool): Do not make online queries for geolocation or DNS - parallel (bool): Parallel processing - keep_alive (callable): Keep alive function + ip_db_path: Path to a MMDB file from MaxMind or DBIP + offline: Do not make online queries for geolocation or DNS + parallel: Parallel processing + keep_alive: Keep alive function Returns: - OrderedDict: The parsed DMARC report + The parsed DMARC report """ - if type(input_) is str: + file_object: BinaryIO + if isinstance(input_, str): logger.debug("Parsing {0}".format(input_)) file_object = open(input_, "rb") - elif type(input_) is bytes: + elif isinstance(input_, bytes): file_object = BytesIO(input_) else: file_object = input_ content = file_object.read() file_object.close() + + results: OrderedDict[str, Union[str, OrderedDict]] try: - report = parse_aggregate_report_file(content, - ip_db_path=ip_db_path, - offline=offline, - nameservers=nameservers, - dns_timeout=dns_timeout, - parallel=parallel, - keep_alive=keep_alive) - results = OrderedDict([("report_type", "aggregate"), - ("report", report)]) + report = parse_aggregate_report_file( + content, + ip_db_path=ip_db_path, + offline=offline, + nameservers=nameservers, + dns_timeout=dns_timeout, + parallel=parallel, + keep_alive=keep_alive, + ) + results = OrderedDict([("report_type", "aggregate"),("report", report)]) + except InvalidAggregateReport: try: - sa = strip_attachment_payloads - results = parse_report_email(content, - ip_db_path=ip_db_path, - offline=offline, - nameservers=nameservers, - dns_timeout=dns_timeout, - strip_attachment_payloads=sa, - parallel=parallel, 
- keep_alive=keep_alive) + results = parse_report_email( + content, + ip_db_path=ip_db_path, + offline=offline, + nameservers=nameservers, + dns_timeout=dns_timeout, + strip_attachment_payloads=strip_attachment_payloads, + parallel=parallel, + keep_alive=keep_alive, + ) except InvalidDMARCReport: raise InvalidDMARCReport("Not a valid aggregate or forensic " "report") return results -def get_dmarc_reports_from_mbox(input_, nameservers=None, dns_timeout=2.0, - strip_attachment_payloads=False, - ip_db_path=None, - offline=False, - parallel=False): +def get_dmarc_reports_from_mbox( + input_: str, + nameservers: Optional[List[str]] = None, + dns_timeout: float = 2.0, + strip_attachment_payloads: bool = False, + ip_db_path: Optional[str] = None, + offline: bool = False, + parallel: bool = False +) -> OrderedDict[str, List[OrderedDict]]: """Parses a mailbox in mbox format containing e-mails with attached DMARC reports Args: input_: A path to a mbox file - nameservers (list): A list of one or more nameservers to use + nameservers: A list of one or more nameservers to use (Cloudflare's public DNS resolvers by default) - dns_timeout (float): Sets the DNS timeout in seconds - strip_attachment_payloads (bool): Remove attachment payloads from + dns_timeout: Sets the DNS timeout in seconds + strip_attachment_payloads: Remove attachment payloads from forensic report results - ip_db_path (str): Path to a MMDB file from MaxMind or DBIP - offline (bool): Do not make online queries for geolocation or DNS - parallel (bool): Parallel processing + ip_db_path: Path to a MMDB file from MaxMind or DBIP + offline: Do not make online queries for geolocation or DNS + parallel: Parallel processing Returns: - OrderedDict: Lists of ``aggregate_reports`` and ``forensic_reports`` + Dictionary of Lists of ``aggregate_reports`` and ``forensic_reports`` """ - aggregate_reports = [] - forensic_reports = [] + aggregate_reports: List[OrderedDict] = [] + forensic_reports: List[OrderedDict] = [] try: mbox = mailbox.mbox(input_) message_keys = mbox.keys() @@ -1032,9 +1071,9 @@ def get_dmarc_reports_from_mbox(input_, nameservers=None, dns_timeout=2.0, strip_attachment_payloads=sa, parallel=parallel) if parsed_email["report_type"] == "aggregate": - aggregate_reports.append(parsed_email["report"]) + aggregate_reports.append(cast(OrderedDict[Any, Any], parsed_email["report"])) elif parsed_email["report_type"] == "forensic": - forensic_reports.append(parsed_email["report"]) + forensic_reports.append(cast(OrderedDict[Any, Any], parsed_email["report"])) except InvalidDMARCReport as error: logger.warning(error.__str__()) except mailbox.NoSuchMailboxError: @@ -1043,19 +1082,21 @@ def get_dmarc_reports_from_mbox(input_, nameservers=None, dns_timeout=2.0, ("forensic_reports", forensic_reports)]) -def get_dmarc_reports_from_mailbox(connection: MailboxConnection, - reports_folder="INBOX", - archive_folder="Archive", - delete=False, - test=False, - ip_db_path=None, - offline=False, - nameservers=None, - dns_timeout=6.0, - strip_attachment_payloads=False, - results=None, - batch_size=10, - create_folders=True): +def get_dmarc_reports_from_mailbox( + connection: MailboxConnection, + reports_folder: str = "INBOX", + archive_folder: str = "Archive", + delete: bool = False, + test: bool = False, + ip_db_path: Optional[str] = None, + offline: bool = False, + nameservers: Optional[List[str]] = None, + dns_timeout: float = 6.0, + strip_attachment_payloads: bool = False, + results: Optional[OrderedDict[str, List[OrderedDict]]] = None, + batch_size: int 
= 10, + create_folders: bool = True, +) -> OrderedDict[str, List[OrderedDict]]: """ Fetches and parses DMARC reports from a mailbox @@ -1063,18 +1104,18 @@ def get_dmarc_reports_from_mailbox(connection: MailboxConnection, connection: A Mailbox connection object reports_folder: The folder where reports can be found archive_folder: The folder to move processed mail to - delete (bool): Delete messages after processing them - test (bool): Do not move or delete messages after processing them - ip_db_path (str): Path to a MMDB file from MaxMind or DBIP - offline (bool): Do not query online for geolocation or DNS - nameservers (list): A list of DNS nameservers to query - dns_timeout (float): Set the DNS query timeout - strip_attachment_payloads (bool): Remove attachment payloads from + delete: Delete messages after processing them + test: Do not move or delete messages after processing them + ip_db_path: Path to a MMDB file from MaxMind or DBIP + offline: Do not query online for geolocation or DNS + nameservers: A list of DNS nameservers to query + dns_timeout: Set the DNS query timeout + strip_attachment_payloads: Remove attachment payloads from forensic report results - results (dict): Results from the previous run - batch_size (int): Number of messages to read and process before saving + results: Results from the previous run + batch_size: Number of messages to read and process before saving (use 0 for no limit) - create_folders (bool): Whether to create the destination folders + create_folders: Whether to create the destination folders (not used in watch) Returns: @@ -1132,10 +1173,10 @@ def get_dmarc_reports_from_mailbox(connection: MailboxConnection, strip_attachment_payloads=sa, keep_alive=connection.keepalive) if parsed_email["report_type"] == "aggregate": - aggregate_reports.append(parsed_email["report"]) + aggregate_reports.append(cast(OrderedDict[Any, Any], parsed_email["report"])) aggregate_report_msg_uids.append(msg_uid) elif parsed_email["report_type"] == "forensic": - forensic_reports.append(parsed_email["report"]) + forensic_reports.append(cast(OrderedDict[Any, Any], parsed_email["report"])) forensic_report_msg_uids.append(msg_uid) except InvalidDMARCReport as error: logger.warning(error.__str__()) @@ -1165,9 +1206,8 @@ def get_dmarc_reports_from_mailbox(connection: MailboxConnection, connection.delete_message(msg_uid) except Exception as e: - message = "Error deleting message UID" - e = "{0} {1}: " "{2}".format(message, msg_uid, e) - logger.error("Mailbox error: {0}".format(e)) + message = "Mailbox error: Error deleting message UID {0}: {1}".format(msg_uid, repr(e)) + logger.error(message) else: if len(aggregate_report_msg_uids) > 0: log_message = "Moving aggregate report messages from" @@ -1185,9 +1225,8 @@ def get_dmarc_reports_from_mailbox(connection: MailboxConnection, connection.move_message(msg_uid, aggregate_reports_folder) except Exception as e: - message = "Error moving message UID" - e = "{0} {1}: {2}".format(message, msg_uid, e) - logger.error("Mailbox error: {0}".format(e)) + message = "Mailbox error: Error moving message UID {0}: {1}".format(msg_uid, repr(e)) + logger.error(message) if len(forensic_report_msg_uids) > 0: message = "Moving forensic report messages from" logger.debug( @@ -1205,9 +1244,9 @@ def get_dmarc_reports_from_mailbox(connection: MailboxConnection, connection.move_message(msg_uid, forensic_reports_folder) except Exception as e: - e = "Error moving message UID {0}: {1}".format( - msg_uid, e) - logger.error("Mailbox error: {0}".format(e)) + message 
= "Mailbox error: Error moving message UID {0}: {1}".format( + msg_uid, repr(e)) + logger.error(message) results = OrderedDict([("aggregate_reports", aggregate_reports), ("forensic_reports", forensic_reports)]) @@ -1232,14 +1271,21 @@ def get_dmarc_reports_from_mailbox(connection: MailboxConnection, return results -def watch_inbox(mailbox_connection: MailboxConnection, - callback: Callable, - reports_folder="INBOX", - archive_folder="Archive", delete=False, test=False, - check_timeout=30, ip_db_path=None, - offline=False, nameservers=None, - dns_timeout=6.0, strip_attachment_payloads=False, - batch_size=None): +def watch_inbox( + mailbox_connection: MailboxConnection, + callback: Callable, + reports_folder: str = "INBOX", + archive_folder: str = "Archive", + delete: bool = False, + test: bool = False, + check_timeout: int = 30, + ip_db_path: Optional[str] = None, + offline: bool = False, + nameservers: Optional[List[str]] = None, + dns_timeout: float = 6.0, + strip_attachment_payloads: bool = False, + batch_size: Optional[int] = None +) -> None: """ Watches the mailbox for new messages and sends the results to a callback function @@ -1249,18 +1295,18 @@ def watch_inbox(mailbox_connection: MailboxConnection, callback: The callback function to receive the parsing results reports_folder: The IMAP folder where reports can be found archive_folder: The folder to move processed mail to - delete (bool): Delete messages after processing them - test (bool): Do not move or delete messages after processing them - check_timeout (int): Number of seconds to wait for a IMAP IDLE response + delete: Delete messages after processing them + test: Do not move or delete messages after processing them + check_timeout: Number of seconds to wait for a IMAP IDLE response or the number of seconds until the next mail check - ip_db_path (str): Path to a MMDB file from MaxMind or DBIP - offline (bool): Do not query online for geolocation or DNS - nameservers (list): A list of one or more nameservers to use + ip_db_path: Path to a MMDB file from MaxMind or DBIP + offline: Do not query online for geolocation or DNS + nameservers: A list of one or more nameservers to use (Cloudflare's public DNS resolvers by default) - dns_timeout (float): Set the DNS query timeout - strip_attachment_payloads (bool): Replace attachment payloads in + dns_timeout: Set the DNS query timeout + strip_attachment_payloads: Replace attachment payloads in forensic report samples with None - batch_size (int): Number of messages to read and process before saving + batch_size: Number of messages to read and process before saving """ def check_callback(connection): @@ -1283,7 +1329,7 @@ def check_callback(connection): check_timeout=check_timeout) -def append_json(filename, reports): +def append_json(filename: str, reports: List[OrderedDict]) -> None: with open(filename, "a+", newline="\n", encoding="utf-8") as output: output_json = json.dumps(reports, ensure_ascii=False, indent=2) if output.seek(0, os.SEEK_END) != 0: @@ -1304,9 +1350,10 @@ def append_json(filename, reports): output.truncate() output.write(output_json) + return -def append_csv(filename, csv): +def append_csv(filename: str, csv: str) -> None: with open(filename, "a+", newline="\n", encoding="utf-8") as output: if output.seek(0, os.SEEK_END) != 0: # strip the headers from the CSV @@ -1316,23 +1363,27 @@ def append_csv(filename, csv): # append it correctly return output.write(csv) + return -def save_output(results, output_directory="output", - aggregate_json_filename="aggregate.json", - 
forensic_json_filename="forensic.json", - aggregate_csv_filename="aggregate.csv", - forensic_csv_filename="forensic.csv"): +def save_output( + results: OrderedDict[str, List[OrderedDict]], + output_directory: str = "output", + aggregate_json_filename: str = "aggregate.json", + forensic_json_filename: str = "forensic.json", + aggregate_csv_filename: str = "aggregate.csv", + forensic_csv_filename: str = "forensic.csv" +) -> None: """ Save report data in the given directory Args: - results (OrderedDict): Parsing results - output_directory (str): The path to the directory to save in - aggregate_json_filename (str): Filename for the aggregate JSON file - forensic_json_filename (str): Filename for the forensic JSON file - aggregate_csv_filename (str): Filename for the aggregate CSV file - forensic_csv_filename (str): Filename for the forensic CSV file + results: Parsing results + output_directory: The path to the directory to save in + aggregate_json_filename: Filename for the aggregate JSON file + forensic_json_filename: Filename for the forensic JSON file + aggregate_csv_filename: Filename for the aggregate CSV file + forensic_csv_filename: Filename for the forensic CSV file """ aggregate_reports = results["aggregate_reports"] @@ -1378,17 +1429,18 @@ def save_output(results, output_directory="output", path = os.path.join(samples_directory, filename) with open(path, "w", newline="\n", encoding="utf-8") as sample_file: sample_file.write(sample) + return -def get_report_zip(results): +def get_report_zip(results: OrderedDict[str, List[OrderedDict]]) -> bytes: """ Creates a zip file of parsed report output Args: - results (OrderedDict): The parsed results + results: The parsed results Returns: - bytes: zip file bytes + raw zip file """ def add_subdir(root_path, subdir): subdir_path = os.path.join(root_path, subdir) @@ -1425,29 +1477,40 @@ def add_subdir(root_path, subdir): return storage.getvalue() -def email_results(results, host, mail_from, mail_to, - mail_cc=None, mail_bcc=None, port=0, - require_encryption=False, verify=True, - username=None, password=None, subject=None, - attachment_filename=None, message=None): +def email_results( + results: OrderedDict[str, List[OrderedDict]], + host: str, + mail_from: str, + mail_to: List[str], + mail_cc: Optional[List[str]] = None, + mail_bcc: Optional[List[str]] = None, + port: int = 0, + require_encryption: bool = False, + verify: bool = True, + username: Optional[str] = None, + password: Optional[str] = None, + subject: Optional[str] = None, + attachment_filename: Optional[str] = None, + message: Optional[str] = None +) -> None: """ Emails parsing results as a zip file Args: - results (OrderedDict): Parsing results + results: Parsing results host: Mail server hostname or IP address mail_from: The value of the message from header - mail_to (list): A list of addresses to mail to - mail_cc (list): A list of addresses to CC - mail_bcc (list): A list addresses to BCC - port (int): Port to use - require_encryption (bool): Require a secure connection from the start - verify (bool): verify the SSL/TLS certificate - username (str): An optional username - password (str): An optional password - subject (str): Overrides the default message subject - attachment_filename (str): Override the default attachment filename - message (str): Override the default plain text body + mail_to: A list of addresses to mail to + mail_cc: A list of addresses to CC + mail_bcc: A list addresses to BCC + port: Port to use + require_encryption: Require a secure connection from the start 
+ verify: verify the SSL/TLS certificate + username: An optional username + password: An optional password + subject: Overrides the default message subject + attachment_filename: Override the default attachment filename + message: Override the default plain text body """ logger.debug("Emailing report to: {0}".format(",".join(mail_to))) date_string = datetime.now().strftime("%Y-%m-%d") @@ -1472,3 +1535,4 @@ def email_results(results, host, mail_from, mail_to, require_encryption=require_encryption, verify=verify, username=username, password=password, subject=subject, attachments=attachments, plain_message=message) + return None diff --git a/parsedmarc/loganalytics.py b/parsedmarc/loganalytics.py index 9ca1496c..78e018e6 100644 --- a/parsedmarc/loganalytics.py +++ b/parsedmarc/loganalytics.py @@ -107,7 +107,7 @@ def publish_json( The stream name where the DMARC reports needs to be pushed. """ try: - logs_client.upload(self.conf.dcr_immutable_id, dcr_stream, results) + logs_client.upload(self.conf.dcr_immutable_id, dcr_stream, results) # type: ignore[attr-defined] except HttpResponseError as e: raise LogAnalyticsException( "Upload failed: {error}" diff --git a/parsedmarc/mail/gmail.py b/parsedmarc/mail/gmail.py index 92ec431d..40f61af4 100644 --- a/parsedmarc/mail/gmail.py +++ b/parsedmarc/mail/gmail.py @@ -2,7 +2,7 @@ from functools import lru_cache from pathlib import Path from time import sleep -from typing import List +from typing import TYPE_CHECKING, List from google.auth.transport.requests import Request from google.oauth2.credentials import Credentials @@ -10,6 +10,10 @@ from googleapiclient.discovery import build from googleapiclient.errors import HttpError +if TYPE_CHECKING: + # https://github.com/henribru/google-api-python-client-stubs?tab=readme-ov-file#explicit-annotations + from googleapiclient._apis.gmail.v1.schemas import ModifyMessageRequest, Label + from parsedmarc.log import logger from parsedmarc.mail.mailbox_connection import MailboxConnection @@ -36,25 +40,27 @@ def _get_creds(token_file, credentials_file, scopes, oauth2_port): class GmailConnection(MailboxConnection): - def __init__(self, - token_file: str, - credentials_file: str, - scopes: List[str], - include_spam_trash: bool, - reports_folder: str, - oauth2_port: int): + def __init__( + self, + token_file: str, + credentials_file: str, + scopes: List[str], + include_spam_trash: bool, + reports_folder: str, + oauth2_port: int, + ): creds = _get_creds(token_file, credentials_file, scopes, oauth2_port) self.service = build('gmail', 'v1', credentials=creds) self.include_spam_trash = include_spam_trash self.reports_label_id = self._find_label_id_for_label(reports_folder) - def create_folder(self, folder_name: str): + def create_folder(self, folder_name: str) -> None: # Gmail doesn't support the name Archive if folder_name == 'Archive': return logger.debug(f"Creating label {folder_name}") - request_body = {'name': folder_name, 'messageListVisibility': 'show'} + request_body: "Label" = {'name': folder_name, 'messageListVisibility': 'show'} try: self.service.users().labels()\ .create(userId='me', body=request_body).execute() @@ -64,6 +70,7 @@ def create_folder(self, folder_name: str): f'skipping creation') else: raise e + return def fetch_messages(self, reports_folder: str, **kwargs) -> List[str]: reports_label_id = self._find_label_id_for_label(reports_folder) @@ -76,7 +83,7 @@ def fetch_messages(self, reports_folder: str, **kwargs) -> List[str]: messages = results.get('messages', []) return [message['id'] for message in 
messages] - def fetch_message(self, message_id): + def fetch_message(self, message_id: str) -> bytes: msg = self.service.users().messages()\ .get(userId='me', id=message_id, @@ -85,13 +92,14 @@ def fetch_message(self, message_id): .execute() return urlsafe_b64decode(msg['raw']) - def delete_message(self, message_id: str): + def delete_message(self, message_id: str) -> None: self.service.users().messages().delete(userId='me', id=message_id) + return def move_message(self, message_id: str, folder_name: str): label_id = self._find_label_id_for_label(folder_name) logger.debug(f"Moving message UID {message_id} to {folder_name}") - request_body = { + request_body: "ModifyMessageRequest" = { 'addLabelIds': [label_id], 'removeLabelIds': [self.reports_label_id] } @@ -100,16 +108,18 @@ def move_message(self, message_id: str, folder_name: str): id=message_id, body=request_body)\ .execute() + return - def keepalive(self): - # Not needed - pass + def keepalive(self) -> None: + # no action needed + return - def watch(self, check_callback, check_timeout): + def watch(self, check_callback, check_timeout) -> None: """ Checks the mailbox for new messages every n seconds""" while True: sleep(check_timeout) check_callback(self) + return @lru_cache(maxsize=10) def _find_label_id_for_label(self, label_name: str) -> str: @@ -118,3 +128,4 @@ def _find_label_id_for_label(self, label_name: str) -> str: for label in labels: if label_name == label['id'] or label_name == label['name']: return label['id'] + raise ValueError(f"Label {label_name} not found") diff --git a/parsedmarc/mail/graph.py b/parsedmarc/mail/graph.py index de565b2c..4387138a 100644 --- a/parsedmarc/mail/graph.py +++ b/parsedmarc/mail/graph.py @@ -2,7 +2,7 @@ from functools import lru_cache from pathlib import Path from time import sleep -from typing import List, Optional +from typing import Dict, List, Optional, Union, Any from azure.identity import UsernamePasswordCredential, \ DeviceCodeCredential, ClientSecretCredential, \ @@ -18,13 +18,15 @@ class AuthMethod(Enum): UsernamePassword = 2 ClientSecret = 3 +Credential = Union[DeviceCodeCredential, UsernamePasswordCredential, ClientSecretCredential] -def _get_cache_args(token_path: Path, allow_unencrypted_storage): - cache_args = { +def _get_cache_args(token_path: Path, allow_unencrypted_storage: bool): + cache_args: Dict[str, Any] = { 'cache_persistence_options': TokenCachePersistenceOptions( name='parsedmarc', - allow_unencrypted_storage=allow_unencrypted_storage) + allow_unencrypted_storage=allow_unencrypted_storage, + ) } auth_record = _load_token(token_path) if auth_record: @@ -46,7 +48,8 @@ def _cache_auth_record(record: AuthenticationRecord, token_path: Path): token_file.write(token) -def _generate_credential(auth_method: str, token_path: Path, **kwargs): +def _generate_credential(auth_method: str, token_path: Path, **kwargs) -> Credential: + credential: Credential if auth_method == AuthMethod.DeviceCode.name: credential = DeviceCodeCredential( client_id=kwargs['client_id'], @@ -57,7 +60,9 @@ def _generate_credential(auth_method: str, token_path: Path, **kwargs): token_path, allow_unencrypted_storage=kwargs['allow_unencrypted_storage']) ) - elif auth_method == AuthMethod.UsernamePassword.name: + return credential + + if auth_method == AuthMethod.UsernamePassword.name: credential = UsernamePasswordCredential( client_id=kwargs['client_id'], client_credential=kwargs['client_secret'], @@ -68,15 +73,16 @@ def _generate_credential(auth_method: str, token_path: Path, **kwargs): token_path, 
allow_unencrypted_storage=kwargs['allow_unencrypted_storage']) ) - elif auth_method == AuthMethod.ClientSecret.name: + return credential + + if auth_method == AuthMethod.ClientSecret.name: credential = ClientSecretCredential( client_id=kwargs['client_id'], tenant_id=kwargs['tenant_id'], client_secret=kwargs['client_secret'] ) - else: - raise RuntimeError(f'Auth method {auth_method} not found') - return credential + return credential + raise RuntimeError(f'Auth method {auth_method} not found') class MSGraphConnection(MailboxConnection): @@ -100,10 +106,10 @@ def __init__(self, tenant_id=tenant_id, token_path=token_path, allow_unencrypted_storage=allow_unencrypted_storage) - client_params = { + client_params: Dict[str, Any] = { 'credential': credential } - if not isinstance(credential, ClientSecretCredential): + if isinstance(credential, (DeviceCodeCredential, UsernamePasswordCredential)): scopes = ['Mail.ReadWrite'] # Detect if mailbox is shared if mailbox and username != mailbox: @@ -115,7 +121,7 @@ def __init__(self, self._client = GraphClient(**client_params) self.mailbox_name = mailbox - def create_folder(self, folder_name: str): + def create_folder(self, folder_name: str) -> None: sub_url = '' path_parts = folder_name.split('/') if len(path_parts) > 1: # Folder is a subFolder @@ -140,9 +146,9 @@ def create_folder(self, folder_name: str): logger.warning(f'Unknown response ' f'{resp.status_code} {resp.json()}') - def fetch_messages(self, folder_name: str, **kwargs) -> List[str]: + def fetch_messages(self, reports_folder: str, **kwargs) -> List[str]: """ Returns a list of message UIDs in the specified folder """ - folder_id = self._find_folder_id_from_folder_path(folder_name) + folder_id = self._find_folder_id_from_folder_path(reports_folder) url = f'/users/{self.mailbox_name}/mailFolders/' \ f'{folder_id}/messages' batch_size = kwargs.get('batch_size') @@ -151,9 +157,9 @@ def fetch_messages(self, folder_name: str, **kwargs) -> List[str]: emails = self._get_all_messages(url, batch_size) return [email['id'] for email in emails] - def _get_all_messages(self, url, batch_size): + def _get_all_messages(self, url: str, batch_size: int): messages: list - params = { + params: Dict[str, Any] = { '$select': 'id' } if batch_size and batch_size > 0: @@ -182,7 +188,7 @@ def mark_message_read(self, message_id: str): raise RuntimeWarning(f"Failed to mark message read" f"{resp.status_code}: {resp.json()}") - def fetch_message(self, message_id: str): + def fetch_message(self, message_id: str) -> str: url = f'/users/{self.mailbox_name}/messages/{message_id}/$value' result = self._client.get(url) if result.status_code != 200: diff --git a/parsedmarc/mail/mailbox_connection.py b/parsedmarc/mail/mailbox_connection.py index ecaa2cbb..e11fa379 100644 --- a/parsedmarc/mail/mailbox_connection.py +++ b/parsedmarc/mail/mailbox_connection.py @@ -1,30 +1,28 @@ from abc import ABC -from typing import List +from typing import List, Union class MailboxConnection(ABC): """ Interface for a mailbox connection """ - def create_folder(self, folder_name: str): + def create_folder(self, folder_name: str) -> None: raise NotImplementedError - def fetch_messages(self, - reports_folder: str, - **kwargs) -> List[str]: + def fetch_messages(self, reports_folder: str, **kwargs) -> List[str]: raise NotImplementedError - def fetch_message(self, message_id) -> str: + def fetch_message(self, message_id: str) -> Union[str, bytes]: raise NotImplementedError - def delete_message(self, message_id: str): + def delete_message(self, message_id: 
str) -> None: raise NotImplementedError - def move_message(self, message_id: str, folder_name: str): + def move_message(self, message_id: str, folder_name: str) -> None: raise NotImplementedError - def keepalive(self): + def keepalive(self) -> None: raise NotImplementedError - def watch(self, check_callback, check_timeout): + def watch(self, check_callback, check_timeout) -> None: raise NotImplementedError diff --git a/parsedmarc/utils.py b/parsedmarc/utils.py index e2728487..08b09924 100644 --- a/parsedmarc/utils.py +++ b/parsedmarc/utils.py @@ -16,16 +16,18 @@ import atexit import mailbox import re +from typing import List, Dict, Any, Optional, Union try: import importlib.resources as pkg_resources except ImportError: # Try backported to PY<37 `importlib_resources` - import importlib_resources as pkg_resources + import importlib_resources as pkg_resources # type: ignore[no-redef] from dateutil.parser import parse as parse_date import dns.reversename import dns.resolver import dns.exception +from expiringdict import ExpiringDict import geoip2.database import geoip2.errors import publicsuffixlist @@ -59,7 +61,7 @@ class DownloadError(RuntimeError): """Raised when an error occurs when downloading a file""" -def decode_base64(data): +def decode_base64(data: str) -> bytes: """ Decodes a base64 string, with padding being optional @@ -67,17 +69,17 @@ def decode_base64(data): data: A base64 encoded string Returns: - bytes: The decoded bytes + The decoded bytes """ - data = bytes(data, encoding="ascii") - missing_padding = len(data) % 4 + data_bytes = bytes(data, encoding="ascii") + missing_padding = len(data_bytes) % 4 if missing_padding != 0: - data += b'=' * (4 - missing_padding) - return base64.b64decode(data) + data_bytes += b'=' * (4 - missing_padding) + return base64.b64decode(data_bytes) -def get_base_domain(domain): +def get_base_domain(domain: str) -> str: """ Gets the base domain name for the given domain @@ -86,30 +88,36 @@ def get_base_domain(domain): https://publicsuffix.org/list/public_suffix_list.dat. 
Args: - domain (str): A domain or subdomain + domain: A domain or subdomain Returns: - str: The base domain of the given domain + The base domain of the given domain """ psl = publicsuffixlist.PublicSuffixList() return psl.privatesuffix(domain) -def query_dns(domain, record_type, cache=None, nameservers=None, timeout=2.0): +def query_dns( + domain: str, + record_type: str, + cache: Optional[ExpiringDict] = None, + nameservers: Optional[List[str]] = None, + timeout: float = 2.0, +) -> List[str]: """ Queries DNS Args: - domain (str): The domain or subdomain to query about - record_type (str): The record type to query for - cache (ExpiringDict): Cache storage - nameservers (list): A list of one or more nameservers to use + domain: The domain or subdomain to query about + record_type: The record type to query for + cache: Cache storage + nameservers: A list of one or more nameservers to use (Cloudflare's public DNS resolvers by default) - timeout (float): Sets the DNS timeout in seconds + timeout: Sets the DNS timeout in seconds Returns: - list: A list of answers + A list of answers """ domain = str(domain).lower() record_type = record_type.upper() @@ -146,23 +154,28 @@ def query_dns(domain, record_type, cache=None, nameservers=None, timeout=2.0): return records -def get_reverse_dns(ip_address, cache=None, nameservers=None, timeout=2.0): +def get_reverse_dns( + ip_address: str, + cache: Optional[ExpiringDict] = None, + nameservers: Optional[List[str]] = None, + timeout: float = 2.0, +) -> Optional[str]: """ Resolves an IP address to a hostname using a reverse DNS query Args: - ip_address (str): The IP address to resolve - cache (ExpiringDict): Cache storage - nameservers (list): A list of one or more nameservers to use + ip_address: The IP address to resolve + cache: Cache storage + nameservers: A list of one or more nameservers to use (Cloudflare's public DNS resolvers by default) - timeout (float): Sets the DNS query timeout in seconds + timeout: Sets the DNS query timeout in seconds Returns: - str: The reverse DNS hostname (if any) + The reverse DNS hostname (if any) """ - hostname = None + hostname: Optional[str] = None try: - address = dns.reversename.from_address(ip_address) + address = str(dns.reversename.from_address(ip_address)) hostname = query_dns(address, "PTR", cache=cache, nameservers=nameservers, timeout=timeout)[0] @@ -173,20 +186,20 @@ def get_reverse_dns(ip_address, cache=None, nameservers=None, timeout=2.0): return hostname -def timestamp_to_datetime(timestamp): +def timestamp_to_datetime(timestamp: int) -> datetime: """ Converts a UNIX/DMARC timestamp to a Python ``datetime`` object Args: - timestamp (int): The timestamp + timestamp: The timestamp Returns: - datetime: The converted timestamp as a Python ``datetime`` object + The converted timestamp as a Python ``datetime`` object """ return datetime.fromtimestamp(int(timestamp)) -def timestamp_to_human(timestamp): +def timestamp_to_human(timestamp: int) -> str: """ Converts a UNIX/DMARC timestamp to a human-readable string @@ -194,21 +207,21 @@ def timestamp_to_human(timestamp): timestamp: The timestamp Returns: - str: The converted timestamp in ``YYYY-MM-DD HH:MM:SS`` format + The converted timestamp in ``YYYY-MM-DD HH:MM:SS`` format """ return timestamp_to_datetime(timestamp).strftime("%Y-%m-%d %H:%M:%S") -def human_timestamp_to_datetime(human_timestamp, to_utc=False): +def human_timestamp_to_datetime(human_timestamp: str, to_utc: bool = False) -> datetime: """ Converts a human-readable timestamp into a Python 
``datetime`` object Args: - human_timestamp (str): A timestamp string - to_utc (bool): Convert the timestamp to UTC + human_timestamp: A timestamp string + to_utc: Convert the timestamp to UTC Returns: - datetime: The converted timestamp + The converted timestamp """ human_timestamp = human_timestamp.replace("-0000", "") @@ -218,31 +231,31 @@ def human_timestamp_to_datetime(human_timestamp, to_utc=False): return dt.astimezone(timezone.utc) if to_utc else dt -def human_timestamp_to_timestamp(human_timestamp): +def human_timestamp_to_timestamp(human_timestamp: str) -> float: """ Converts a human-readable timestamp into a UNIX timestamp Args: - human_timestamp (str): A timestamp in `YYYY-MM-DD HH:MM:SS`` format + human_timestamp: A timestamp in `YYYY-MM-DD HH:MM:SS`` format Returns: - float: The converted timestamp + The converted timestamp """ human_timestamp = human_timestamp.replace("T", " ") return human_timestamp_to_datetime(human_timestamp).timestamp() -def get_ip_address_country(ip_address, db_path=None): +def get_ip_address_country(ip_address: str, db_path: Optional[str] = None) -> Optional[str]: """ Returns the ISO code for the country associated with the given IPv4 or IPv6 address Args: - ip_address (str): The IP address to query for - db_path (str): Path to a MMDB file from MaxMind or DBIP + ip_address: The IP address to query for + db_path: Path to a MMDB file from MaxMind or DBIP Returns: - str: And ISO country code associated with the given IP address + And ISO country code associated with the given IP address """ db_paths = [ "GeoLite2-Country.mmdb", @@ -274,7 +287,7 @@ def get_ip_address_country(ip_address, db_path=None): if db_path is None: with pkg_resources.path(parsedmarc.resources.dbip, "dbip-country-lite.mmdb") as path: - db_path = path + db_path = str(path) db_age = datetime.now() - datetime.fromtimestamp( os.stat(db_path).st_mtime) @@ -293,23 +306,30 @@ def get_ip_address_country(ip_address, db_path=None): return country -def get_ip_address_info(ip_address, ip_db_path=None, cache=None, offline=False, - nameservers=None, timeout=2.0, parallel=False): +def get_ip_address_info( + ip_address: str, + ip_db_path: Optional[str] = None, + cache: Optional[ExpiringDict] = None, + offline: bool = False, + nameservers: Optional[List[str]] = None, + timeout: float = 2.0, + parallel: bool = False +) -> OrderedDict: """ Returns reverse DNS and country information for the given IP address Args: - ip_address (str): The IP address to check - ip_db_path (str): path to a MMDB file from MaxMind or DBIP - cache (ExpiringDict): Cache storage - offline (bool): Do not make online queries for geolocation or DNS - nameservers (list): A list of one or more nameservers to use + ip_address: The IP address to check + ip_db_path: path to a MMDB file from MaxMind or DBIP + cache: Cache storage + offline: Do not make online queries for geolocation or DNS + nameservers: A list of one or more nameservers to use (Cloudflare's public DNS resolvers by default) - timeout (float): Sets the DNS timeout in seconds - parallel (bool): parallel processing + timeout: Sets the DNS timeout in seconds + parallel: parallel processing (not used) Returns: - OrderedDict: ``ip_address``, ``reverse_dns`` + Dictionary of (`ip_address`, `country`, `reverse_dns`, `base_domain`) """ ip_address = ip_address.lower() @@ -336,7 +356,7 @@ def get_ip_address_info(ip_address, ip_db_path=None, cache=None, offline=False, return info -def parse_email_address(original_address): +def parse_email_address(original_address: str) -> 
OrderedDict: if original_address[0] == "": display_name = None else: @@ -355,15 +375,15 @@ def parse_email_address(original_address): ("domain", domain)]) -def get_filename_safe_string(string): +def get_filename_safe_string(string: str) -> str: """ Converts a string to a string that is safe for a filename Args: - string (str): A string to make safe for a filename + string: A string to make safe for a filename Returns: - str: A string safe for a filename + A string safe for a filename """ invalid_filename_chars = ['\\', '/', ':', '"', '*', '?', '|', '\n', '\r'] @@ -378,15 +398,15 @@ def get_filename_safe_string(string): return string -def is_mbox(path): +def is_mbox(path: str) -> bool: """ Checks if the given content is an MBOX mailbox file Args: - path: Content to check + Content to check Returns: - bool: A flag that indicates if the file is an MBOX mailbox file + If the file is an MBOX mailbox file """ _is_mbox = False try: @@ -399,7 +419,7 @@ def is_mbox(path): return _is_mbox -def is_outlook_msg(content): +def is_outlook_msg(content: Any) -> bool: """ Checks if the given content is an Outlook msg OLE/MSG file @@ -407,19 +427,19 @@ def is_outlook_msg(content): content: Content to check Returns: - bool: A flag that indicates if the file is an Outlook MSG file + If the file is an Outlook MSG file """ return isinstance(content, bytes) and content.startswith( b"\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1") -def convert_outlook_msg(msg_bytes): +def convert_outlook_msg(msg_bytes: bytes) -> bytes: """ Uses the ``msgconvert`` Perl utility to convert an Outlook MS file to standard RFC 822 format Args: - msg_bytes (bytes): the content of the .msg file + msg_bytes: the content of the .msg file Returns: A RFC 822 string @@ -447,16 +467,16 @@ def convert_outlook_msg(msg_bytes): return rfc822 -def parse_email(data, strip_attachment_payloads=False): +def parse_email(data: Union[bytes, str], strip_attachment_payloads: bool = False) -> Dict[str, Any]: """ A simplified email parser Args: data: The RFC 822 message string, or MSG binary - strip_attachment_payloads (bool): Remove attachment payloads + strip_attachment_payloads: Remove attachment payloads Returns: - dict: Parsed email data + Parsed email data """ if isinstance(data, bytes): diff --git a/pyproject.toml b/pyproject.toml index d6a098ff..261c6467 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,9 @@ classifiers = [ "Operating System :: OS Independent", "Programming Language :: Python :: 3" ] + +requires-python = ">=3.7" + dependencies = [ "azure-identity>=1.8.0", "azure-monitor-ingestion>=1.0.0", From 5e741a3bf535d100d3d1e4e108f3d4fd1ad8fe97 Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Wed, 3 Jan 2024 01:44:30 +1100 Subject: [PATCH 03/15] Run black formatter, refactor LogAnayticsClient, update doc strings --- parsedmarc/__init__.py | 644 +++++++++++---------- parsedmarc/cli.py | 801 +++++++++++++------------- parsedmarc/elastic.py | 270 ++++----- parsedmarc/kafkaclient.py | 67 +-- parsedmarc/loganalytics.py | 185 ++---- parsedmarc/mail/__init__.py | 5 +- parsedmarc/mail/gmail.py | 76 ++- parsedmarc/mail/graph.py | 207 ++++--- parsedmarc/mail/imap.py | 62 +- parsedmarc/mail/mailbox_connection.py | 5 +- parsedmarc/s3.py | 34 +- parsedmarc/splunk.py | 48 +- parsedmarc/syslog.py | 8 +- parsedmarc/utils.py | 164 +++--- 14 files changed, 1241 insertions(+), 1335 deletions(-) diff --git a/parsedmarc/__init__.py b/parsedmarc/__init__.py index af26423d..48c3b74b 100644 --- a/parsedmarc/__init__.py +++ b/parsedmarc/__init__.py @@ -72,11 +72,9 @@ 
def _parse_report_record( offline: bool = False, nameservers: Optional[List[str]] = None, dns_timeout: float = 2.0, - parallel: bool = False + parallel: bool = False, ) -> OrderedDict: - """ - Converts a record from a DMARC aggregate report into a more consistent - format + """Convert a record from a DMARC aggregate report into a more consistent format Args: record: The record to convert @@ -91,21 +89,26 @@ def _parse_report_record( """ record = record.copy() new_record: OrderedDict[str, Any] = OrderedDict() - new_record_source = get_ip_address_info(record["row"]["source_ip"], - cache=IP_ADDRESS_CACHE, - ip_db_path=ip_db_path, - offline=offline, - nameservers=nameservers, - timeout=dns_timeout, - parallel=parallel) + new_record_source = get_ip_address_info( + record["row"]["source_ip"], + cache=IP_ADDRESS_CACHE, + ip_db_path=ip_db_path, + offline=offline, + nameservers=nameservers, + timeout=dns_timeout, + parallel=parallel, + ) new_record["source"] = new_record_source new_record["count"] = int(record["row"]["count"]) policy_evaluated = record["row"]["policy_evaluated"].copy() - new_policy_evaluated = OrderedDict([("disposition", "none"), - ("dkim", "fail"), - ("spf", "fail"), - ("policy_override_reasons", []) - ]) + new_policy_evaluated = OrderedDict( + [ + ("disposition", "none"), + ("dkim", "fail"), + ("spf", "fail"), + ("policy_override_reasons", []), + ] + ) if "disposition" in policy_evaluated: new_policy_evaluated["disposition"] = policy_evaluated["disposition"] if cast(str, new_policy_evaluated["disposition"]).strip().lower() == "pass": @@ -115,10 +118,10 @@ def _parse_report_record( if "spf" in policy_evaluated: new_policy_evaluated["spf"] = policy_evaluated["spf"] reasons = [] - spf_aligned = policy_evaluated["spf"] is not None and policy_evaluated[ - "spf"].lower() == "pass" - dkim_aligned = policy_evaluated["dkim"] is not None and policy_evaluated[ - "dkim"].lower() == "pass" + spf_aligned = policy_evaluated["spf"] is not None and policy_evaluated["spf"].lower() == "pass" + dkim_aligned = ( + policy_evaluated["dkim"] is not None and policy_evaluated["dkim"].lower() == "pass" + ) dmarc_aligned = spf_aligned or dkim_aligned new_record["alignment"] = dict() new_record["alignment"]["spf"] = spf_aligned @@ -142,7 +145,7 @@ def _parse_report_record( if type(new_record["identifiers"]["header_from"]) is str: lowered_from = new_record["identifiers"]["header_from"].lower() else: - lowered_from = '' + lowered_from = "" new_record["identifiers"]["header_from"] = lowered_from if record["auth_results"] is not None: auth_results = record["auth_results"].copy() @@ -217,7 +220,7 @@ def parse_aggregate_report_xml( nameservers: Optional[List[str]] = None, timeout: float = 2.0, parallel: bool = False, - keep_alive: Optional[Callable] = None + keep_alive: Optional[Callable] = None, ) -> OrderedDict[str, Any]: """Parses a DMARC XML report string and returns a consistent OrderedDict @@ -225,8 +228,7 @@ def parse_aggregate_report_xml( xml: A string of DMARC aggregate report XML ip_db_path: Path to a MMDB file from MaxMind or DBIP offline: Do not query online for geolocation or DNS - nameservers: A list of one or more nameservers to use - (Cloudflare's public DNS resolvers by default) + nameservers: A list of one or more nameservers to use (Cloudflare's public DNS resolvers by default) timeout: Sets the DNS timeout in seconds parallel: Parallel processing keep_alive: Keep alive function @@ -242,19 +244,19 @@ def parse_aggregate_report_xml( errors.append("Invalid XML: {0}".format(e.__str__())) try: tree 
= etree.parse( - BytesIO(xml.encode('utf-8')), - etree.XMLParser(recover=True, resolve_entities=False)) + BytesIO(xml.encode("utf-8")), etree.XMLParser(recover=True, resolve_entities=False) + ) s = etree.tostring(tree) - xml = '' if s is None else s.decode('utf-8') + xml = "" if s is None else s.decode("utf-8") except Exception: - xml = u'' + xml = "" try: # Replace XML header (sometimes they are invalid) - xml = xml_header_regex.sub("", xml) + xml = xml_header_regex.sub('', xml) # Remove invalid schema tags - xml = xml_schema_regex.sub('', xml) + xml = xml_schema_regex.sub("", xml) report = xmltodict.parse(xml)["feedback"] report_metadata = report["report_metadata"] @@ -265,20 +267,19 @@ def parse_aggregate_report_xml( new_report_metadata: OrderedDict[str, Any] = OrderedDict() if report_metadata["org_name"] is None: if report_metadata["email"] is not None: - report_metadata["org_name"] = report_metadata[ - "email"].split("@")[-1] + report_metadata["org_name"] = report_metadata["email"].split("@")[-1] org_name = report_metadata["org_name"] if org_name is not None and " " not in org_name: new_org_name = get_base_domain(org_name) if new_org_name is not None: org_name = new_org_name if not org_name: - logger.debug("Could not parse org_name from XML.\r\n{0}".format( - report.__str__() - )) - raise KeyError("Organization name is missing. \ + logger.debug("Could not parse org_name from XML.\r\n{0}".format(report.__str__())) + raise KeyError( + "Organization name is missing. \ This field is a requirement for \ - saving the report") + saving the report" + ) new_report_metadata["org_name"] = org_name new_report_metadata["org_email"] = report_metadata["email"] extra = None @@ -287,11 +288,10 @@ def parse_aggregate_report_xml( new_report_metadata["org_extra_contact_info"] = extra new_report_metadata["report_id"] = report_metadata["report_id"] report_id = new_report_metadata["report_id"] - report_id = report_id.replace("<", - "").replace(">", "").split("@")[0] + report_id = report_id.replace("<", "").replace(">", "").split("@")[0] new_report_metadata["report_id"] = report_id date_range = report["report_metadata"]["date_range"] - if int(date_range["end"]) - int(date_range["begin"]) > 2*86400: + if int(date_range["end"]) - int(date_range["begin"]) > 2 * 86400: _error = "Timespan > 24 hours - RFC 7489 section 7.2" errors.append(_error) date_range["begin"] = timestamp_to_human(date_range["begin"]) @@ -342,23 +342,26 @@ def parse_aggregate_report_xml( if keep_alive is not None and i > 0 and i % 20 == 0: logger.debug("Sending keepalive cmd") keep_alive() - logger.debug("Processed {0}/{1}".format( - i, len(report["record"]))) - report_record = _parse_report_record(report["record"][i], - ip_db_path=ip_db_path, - offline=offline, - nameservers=nameservers, - dns_timeout=timeout, - parallel=parallel) + logger.debug("Processed {0}/{1}".format(i, len(report["record"]))) + report_record = _parse_report_record( + report["record"][i], + ip_db_path=ip_db_path, + offline=offline, + nameservers=nameservers, + dns_timeout=timeout, + parallel=parallel, + ) records.append(report_record) else: - report_record = _parse_report_record(report["record"], - ip_db_path=ip_db_path, - offline=offline, - nameservers=nameservers, - dns_timeout=timeout, - parallel=parallel) + report_record = _parse_report_record( + report["record"], + ip_db_path=ip_db_path, + offline=offline, + nameservers=nameservers, + dns_timeout=timeout, + parallel=parallel, + ) records.append(report_record) new_report["records"] = records @@ -366,24 +369,19 @@ 
def parse_aggregate_report_xml( return new_report except expat.ExpatError as error: - raise InvalidAggregateReport( - "Invalid XML: {0}".format(error.__str__())) + raise InvalidAggregateReport("Invalid XML: {0}".format(error.__str__())) except KeyError as error: - raise InvalidAggregateReport( - "Missing field: {0}".format(error.__str__())) + raise InvalidAggregateReport("Missing field: {0}".format(error.__str__())) except AttributeError: raise InvalidAggregateReport("Report missing required section") except Exception as error: - raise InvalidAggregateReport( - "Unexpected error: {0}".format(error.__str__())) + raise InvalidAggregateReport("Unexpected error: {0}".format(error.__str__())) def extract_xml(input_: Union[str, bytes, BinaryIO]) -> str: - """ - Extracts xml from a zip or gzip file at the given path, file-like object, - or bytes. + """Extracts xml from a zip or gzip file at the given path, file-like object, or bytes. Args: input_: A path to a file, a file like object, or bytes @@ -406,13 +404,13 @@ def extract_xml(input_: Union[str, bytes, BinaryIO]) -> str: if header.startswith(MAGIC_ZIP): _zip = zipfile.ZipFile(file_object) - xml = _zip.open(_zip.namelist()[0]).read().decode(errors='ignore') + xml = _zip.open(_zip.namelist()[0]).read().decode(errors="ignore") elif header.startswith(MAGIC_GZIP): - xml = zlib.decompress(file_object.read(), zlib.MAX_WBITS | 16).decode(errors='ignore') + xml = zlib.decompress(file_object.read(), zlib.MAX_WBITS | 16).decode(errors="ignore") elif header.startswith(MAGIC_XML): - xml = file_object.read().decode(errors='ignore') + xml = file_object.read().decode(errors="ignore") else: file_object.close() @@ -424,27 +422,24 @@ def extract_xml(input_: Union[str, bytes, BinaryIO]) -> str: raise InvalidAggregateReport("File was not found") except UnicodeDecodeError: file_object.close() - raise InvalidAggregateReport("File objects must be opened in binary " - "(rb) mode") + raise InvalidAggregateReport("File objects must be opened in binary " "(rb) mode") except Exception as error: file_object.close() - raise InvalidAggregateReport( - "Invalid archive file: {0}".format(error.__str__())) + raise InvalidAggregateReport("Invalid archive file: {0}".format(error.__str__())) return xml def parse_aggregate_report_file( _input: Union[bytes, str, BinaryIO], - offline: bool =False, + offline: bool = False, ip_db_path: Optional[str] = None, nameservers: Optional[List[str]] = None, dns_timeout: float = 2.0, parallel: bool = False, - keep_alive: Optional[Callable] = None + keep_alive: Optional[Callable] = None, ) -> OrderedDict[str, Any]: - """Parses a file at the given path, a file-like object. or bytes as an - aggregate DMARC report + """Parse a file at the given path, a file-like object. 
or bytes as an aggregate DMARC report Args: _input: A path to a file, a file like object, or bytes @@ -472,17 +467,16 @@ def parse_aggregate_report_file( ) -def parsed_aggregate_reports_to_csv_rows(reports: Union[OrderedDict, List[OrderedDict]]) -> List[Dict[str, Union[str, int, bool]]]: - """ - Converts one or more parsed aggregate reports to list of dicts in flat CSV - format +def parsed_aggregate_reports_to_csv_rows( + reports: Union[OrderedDict, List[OrderedDict]] +) -> List[Dict[str, Union[str, int, bool]]]: + """Convert one or more parsed aggregate reports to list of dicts in flat CSV format Args: reports: A parsed aggregate report or list of parsed aggregate reports Returns: - Parsed aggregate report data as a list of dicts in flat CSV - format + Parsed aggregate report data as a list of dicts in flat CSV format """ def to_str(obj): @@ -510,12 +504,23 @@ def to_str(obj): pct = report["policy_published"]["pct"] fo = report["policy_published"]["fo"] - report_dict = dict(xml_schema=xml_schema, org_name=org_name, - org_email=org_email, - org_extra_contact_info=org_extra_contact, - report_id=report_id, begin_date=begin_date, - end_date=end_date, errors=errors, domain=domain, - adkim=adkim, aspf=aspf, p=p, sp=sp, pct=pct, fo=fo) + report_dict = dict( + xml_schema=xml_schema, + org_name=org_name, + org_email=org_email, + org_extra_contact_info=org_extra_contact, + report_id=report_id, + begin_date=begin_date, + end_date=end_date, + errors=errors, + domain=domain, + adkim=adkim, + aspf=aspf, + p=p, + sp=sp, + pct=pct, + fo=fo, + ) for record in report["records"]: row = report_dict.copy() @@ -528,18 +533,17 @@ def to_str(obj): row["dkim_aligned"] = record["alignment"]["dkim"] row["dmarc_aligned"] = record["alignment"]["dmarc"] row["disposition"] = record["policy_evaluated"]["disposition"] - policy_override_reasons = list(map( - lambda r_: r_["type"], - record["policy_evaluated"] - ["policy_override_reasons"])) - policy_override_comments = list(map( - lambda r_: r_["comment"] or "none", - record["policy_evaluated"] - ["policy_override_reasons"])) - row["policy_override_reasons"] = ",".join( - policy_override_reasons) - row["policy_override_comments"] = "|".join( - policy_override_comments) + policy_override_reasons = list( + map(lambda r_: r_["type"], record["policy_evaluated"]["policy_override_reasons"]) + ) + policy_override_comments = list( + map( + lambda r_: r_["comment"] or "none", + record["policy_evaluated"]["policy_override_reasons"], + ) + ) + row["policy_override_reasons"] = ",".join(policy_override_reasons) + row["policy_override_comments"] = "|".join(policy_override_comments) row["envelope_from"] = record["identifiers"]["envelope_from"] row["header_from"] = record["identifiers"]["header_from"] envelope_to = record["identifiers"]["envelope_to"] @@ -570,15 +574,13 @@ def to_str(obj): for r in rows: for k, v in r.items(): if type(v) not in [str, int, bool]: - r[k] = '' + r[k] = "" return rows def parsed_aggregate_reports_to_csv(reports: Union[OrderedDict, List[OrderedDict]]) -> str: - """ - Converts one or more parsed aggregate reports to flat CSV format, including - headers + """Convert one or more parsed aggregate reports to flat CSV format, including headers Args: reports: A parsed aggregate report or list of parsed aggregate reports @@ -587,16 +589,43 @@ def parsed_aggregate_reports_to_csv(reports: Union[OrderedDict, List[OrderedDict Parsed aggregate report data in flat CSV format, including headers """ - fields = ["xml_schema", "org_name", "org_email", - 
"org_extra_contact_info", "report_id", "begin_date", "end_date", - "errors", "domain", "adkim", "aspf", "p", "sp", "pct", "fo", - "source_ip_address", "source_country", "source_reverse_dns", - "source_base_domain", "count", "spf_aligned", - "dkim_aligned", "dmarc_aligned", "disposition", - "policy_override_reasons", "policy_override_comments", - "envelope_from", "header_from", - "envelope_to", "dkim_domains", "dkim_selectors", "dkim_results", - "spf_domains", "spf_scopes", "spf_results"] + fields = [ + "xml_schema", + "org_name", + "org_email", + "org_extra_contact_info", + "report_id", + "begin_date", + "end_date", + "errors", + "domain", + "adkim", + "aspf", + "p", + "sp", + "pct", + "fo", + "source_ip_address", + "source_country", + "source_reverse_dns", + "source_base_domain", + "count", + "spf_aligned", + "dkim_aligned", + "dmarc_aligned", + "disposition", + "policy_override_reasons", + "policy_override_comments", + "envelope_from", + "header_from", + "envelope_to", + "dkim_domains", + "dkim_selectors", + "dkim_results", + "spf_domains", + "spf_scopes", + "spf_results", + ] csv_file_object = StringIO(newline="\n") writer = DictWriter(csv_file_object, fields) @@ -620,10 +649,9 @@ def parse_forensic_report( nameservers: Optional[List[str]] = None, dns_timeout: float = 2.0, strip_attachment_payloads: bool = False, - parallel: bool = False + parallel: bool = False, ) -> OrderedDict: - """ - Converts a DMARC forensic report and sample to a ``OrderedDict`` + """Converts a DMARC forensic report and sample to a ``OrderedDict`` Args: feedback_report: A message's feedback report as a string @@ -652,8 +680,7 @@ def parse_forensic_report( if "arrival_date" not in parsed_report: if msg_date is None: - raise InvalidForensicReport( - "Forensic sample is not a valid email") + raise InvalidForensicReport("Forensic sample is not a valid email") parsed_report["arrival_date"] = msg_date.isoformat() if "version" not in parsed_report: @@ -673,15 +700,18 @@ def parse_forensic_report( parsed_report["delivery_result"] = "other" parsed_report["arrival_date_utc"] = human_timestamp_to_datetime( - parsed_report["arrival_date"], to_utc=True).strftime("%Y-%m-%d %H:%M:%S") - - ip_address = re.split(r'\s', parsed_report["source_ip"]).pop(0) - parsed_report_source = get_ip_address_info(ip_address, - ip_db_path=ip_db_path, - offline=offline, - nameservers=nameservers, - timeout=dns_timeout, - parallel=parallel) + parsed_report["arrival_date"], to_utc=True + ).strftime("%Y-%m-%d %H:%M:%S") + + ip_address = re.split(r"\s", parsed_report["source_ip"]).pop(0) + parsed_report_source = get_ip_address_info( + ip_address, + ip_db_path=ip_db_path, + offline=offline, + nameservers=nameservers, + timeout=dns_timeout, + parallel=parallel, + ) parsed_report["source"] = parsed_report_source del parsed_report["source_ip"] @@ -701,15 +731,17 @@ def parse_forensic_report( auth_failure = parsed_report["auth_failure"].split(",") parsed_report["auth_failure"] = auth_failure - optional_fields = ["original_envelope_id", "dkim_domain", - "original_mail_from", "original_rcpt_to"] + optional_fields = [ + "original_envelope_id", + "dkim_domain", + "original_mail_from", + "original_rcpt_to", + ] for optional_field in optional_fields: if optional_field not in parsed_report: parsed_report[optional_field] = None - parsed_sample = parse_email( - sample, - strip_attachment_payloads=strip_attachment_payloads) + parsed_sample = parse_email(sample, strip_attachment_payloads=strip_attachment_payloads) if "reported_domain" not in parsed_report: 
parsed_report["reported_domain"] = parsed_sample["from"]["domain"] @@ -729,18 +761,16 @@ def parse_forensic_report( return parsed_report except KeyError as error: - raise InvalidForensicReport("Missing value: {0}".format( - error.__str__())) + raise InvalidForensicReport("Missing value: {0}".format(error.__str__())) except Exception as error: - raise InvalidForensicReport( - "Unexpected error: {0}".format(error.__str__())) + raise InvalidForensicReport("Unexpected error: {0}".format(error.__str__())) -def parsed_forensic_reports_to_csv_rows(reports: Union[OrderedDict, List[OrderedDict]]) -> List[Dict[str, Any]]: - """ - Converts one or more parsed forensic reports to a list of dicts in flat CSV - format +def parsed_forensic_reports_to_csv_rows( + reports: Union[OrderedDict, List[OrderedDict]] +) -> List[Dict[str, Any]]: + """Convert one or more parsed forensic reports to a list of dicts in flat CSV format Args: reports: A parsed forensic report or list of parsed forensic reports @@ -763,8 +793,7 @@ def parsed_forensic_reports_to_csv_rows(reports: Union[OrderedDict, List[Ordered row["subject"] = report["parsed_sample"]["subject"] row["auth_failure"] = ",".join(report["auth_failure"]) authentication_mechanisms = report["authentication_mechanisms"] - row["authentication_mechanisms"] = ",".join( - authentication_mechanisms) + row["authentication_mechanisms"] = ",".join(authentication_mechanisms) del row["sample"] del row["parsed_sample"] rows.append(row) @@ -773,9 +802,7 @@ def parsed_forensic_reports_to_csv_rows(reports: Union[OrderedDict, List[Ordered def parsed_forensic_reports_to_csv(reports: Union[OrderedDict, List[OrderedDict]]) -> str: - """ - Converts one or more parsed forensic reports to flat CSV format, including - headers + """Convert one or more parsed forensic reports to flat CSV format, including headers Args: reports: A parsed forensic report or list of parsed forensic reports @@ -783,13 +810,29 @@ def parsed_forensic_reports_to_csv(reports: Union[OrderedDict, List[OrderedDict] Returns: Parsed forensic report data in flat CSV format, including headers """ - fields = ["feedback_type", "user_agent", "version", "original_envelope_id", - "original_mail_from", "original_rcpt_to", "arrival_date", - "arrival_date_utc", "subject", "message_id", - "authentication_results", "dkim_domain", "source_ip_address", - "source_country", "source_reverse_dns", "source_base_domain", - "delivery_result", "auth_failure", "reported_domain", - "authentication_mechanisms", "sample_headers_only"] + fields = [ + "feedback_type", + "user_agent", + "version", + "original_envelope_id", + "original_mail_from", + "original_rcpt_to", + "arrival_date", + "arrival_date_utc", + "subject", + "message_id", + "authentication_results", + "dkim_domain", + "source_ip_address", + "source_country", + "source_reverse_dns", + "source_base_domain", + "delivery_result", + "auth_failure", + "reported_domain", + "authentication_mechanisms", + "sample_headers_only", + ] csv_file = StringIO() csv_writer = DictWriter(csv_file, fieldnames=fields) @@ -816,8 +859,7 @@ def parse_report_email( parallel: bool = False, keep_alive: Optional[Callable] = None, ) -> OrderedDict[str, Union[str, OrderedDict]]: - """ - Parses a DMARC report from an email + """Parse a DMARC report from an email Args: input_: An emailed DMARC report in RFC 822 format, as bytes or a string @@ -867,8 +909,7 @@ def parse_report_email( feedback_report = payload else: feedback_report = b64decode(payload).__str__() - feedback_report = feedback_report.lstrip( - 
"b'").rstrip("'") + feedback_report = feedback_report.lstrip("b'").rstrip("'") feedback_report = feedback_report.replace("\\r", "") feedback_report = feedback_report.replace("\\n", "\n") except (ValueError, TypeError, binascii.Error): @@ -886,19 +927,22 @@ def parse_report_email( for match in field_matches: field_name = match[0].lower().replace(" ", "-") fields[field_name] = match[1].strip() - feedback_report = "Arrival-Date: {}\n" \ - "Source-IP: {}" \ - "".format(fields["received-date"], - fields["sender-ip-address"]) + feedback_report = ( + "Arrival-Date: {}\n" + "Source-IP: {}" + "".format(fields["received-date"], fields["sender-ip-address"]) + ) sample = parts[1].lstrip() sample = sample.replace("=\r\n", "") logger.debug(sample) else: try: payload = b64decode(payload) - if payload.startswith(MAGIC_ZIP) or \ - payload.startswith(MAGIC_GZIP) or \ - payload.startswith(MAGIC_XML): + if ( + payload.startswith(MAGIC_ZIP) + or payload.startswith(MAGIC_GZIP) + or payload.startswith(MAGIC_XML) + ): ns = nameservers aggregate_report = parse_aggregate_report_file( payload, @@ -907,21 +951,23 @@ def parse_report_email( nameservers=ns, dns_timeout=dns_timeout, parallel=parallel, - keep_alive=keep_alive) + keep_alive=keep_alive, + ) return OrderedDict([("report_type", "aggregate"), ("report", aggregate_report)]) except (TypeError, ValueError, binascii.Error): pass except InvalidAggregateReport as e: - error = 'Message with subject "{0}" ' \ - 'is not a valid ' \ - 'aggregate DMARC report: {1}'.format(subject, e) + error = ( + 'Message with subject "{0}" ' + "is not a valid " + "aggregate DMARC report: {1}".format(subject, e) + ) raise InvalidAggregateReport(error) except Exception as e: - error = 'Unable to parse message with ' \ - 'subject "{0}": {1}'.format(subject, e) + error = "Unable to parse message with " 'subject "{0}": {1}'.format(subject, e) raise InvalidDMARCReport(error) if feedback_report and sample: @@ -934,11 +980,14 @@ def parse_report_email( nameservers=nameservers, dns_timeout=dns_timeout, strip_attachment_payloads=strip_attachment_payloads, - parallel=parallel) + parallel=parallel, + ) except InvalidForensicReport as e: - error = 'Message with subject "{0}" ' \ - 'is not a valid ' \ - 'forensic DMARC report: {1}'.format(subject, e) + error = ( + 'Message with subject "{0}" ' + "is not a valid " + "forensic DMARC report: {1}".format(subject, e) + ) raise InvalidForensicReport(error) except Exception as e: raise InvalidForensicReport(e.__str__()) @@ -957,18 +1006,15 @@ def parse_report_file( ip_db_path: Optional[str] = None, offline: bool = False, parallel: bool = False, - keep_alive: Optional[Callable] = None + keep_alive: Optional[Callable] = None, ) -> OrderedDict: - """Parses a DMARC aggregate or forensic file at the given path, a - file-like object. or bytes + """Parse a DMARC aggregate or forensic file at the given path, a file-like object. 
or bytes Args: input_: A path to a file, a file like object, or bytes - nameservers: A list of one or more nameservers to use - (Cloudflare's public DNS resolvers by default) + nameservers: A list of one or more nameservers to use (Cloudflare's public DNS resolvers by default) dns_timeout: Sets the DNS timeout in seconds - strip_attachment_payloads: Remove attachment payloads from - forensic report results + strip_attachment_payloads: Remove attachment payloads from forensic report results ip_db_path: Path to a MMDB file from MaxMind or DBIP offline: Do not make online queries for geolocation or DNS parallel: Parallel processing @@ -1000,7 +1046,7 @@ def parse_report_file( parallel=parallel, keep_alive=keep_alive, ) - results = OrderedDict([("report_type", "aggregate"),("report", report)]) + results = OrderedDict([("report_type", "aggregate"), ("report", report)]) except InvalidAggregateReport: try: @@ -1015,8 +1061,7 @@ def parse_report_file( keep_alive=keep_alive, ) except InvalidDMARCReport: - raise InvalidDMARCReport("Not a valid aggregate or forensic " - "report") + raise InvalidDMARCReport("Not a valid aggregate or forensic " "report") return results @@ -1027,18 +1072,15 @@ def get_dmarc_reports_from_mbox( strip_attachment_payloads: bool = False, ip_db_path: Optional[str] = None, offline: bool = False, - parallel: bool = False + parallel: bool = False, ) -> OrderedDict[str, List[OrderedDict]]: - """Parses a mailbox in mbox format containing e-mails with attached - DMARC reports + """Parses a mailbox in mbox format containing e-mails with attached DMARC reports Args: input_: A path to a mbox file - nameservers: A list of one or more nameservers to use - (Cloudflare's public DNS resolvers by default) + nameservers: A list of one or more nameservers to use (Cloudflare's public DNS resolvers by default) dns_timeout: Sets the DNS timeout in seconds - strip_attachment_payloads: Remove attachment payloads from - forensic report results + strip_attachment_payloads: Remove attachment payloads from forensic report results ip_db_path: Path to a MMDB file from MaxMind or DBIP offline: Do not make online queries for geolocation or DNS parallel: Parallel processing @@ -1053,23 +1095,22 @@ def get_dmarc_reports_from_mbox( mbox = mailbox.mbox(input_) message_keys = mbox.keys() total_messages = len(message_keys) - logger.debug("Found {0} messages in {1}".format(total_messages, - input_)) + logger.debug("Found {0} messages in {1}".format(total_messages, input_)) for i in range(len(message_keys)): message_key = message_keys[i] - logger.info("Processing message {0} of {1}".format( - i+1, total_messages - )) + logger.info("Processing message {0} of {1}".format(i + 1, total_messages)) msg_content = mbox.get_string(message_key) try: sa = strip_attachment_payloads - parsed_email = parse_report_email(msg_content, - ip_db_path=ip_db_path, - offline=offline, - nameservers=nameservers, - dns_timeout=dns_timeout, - strip_attachment_payloads=sa, - parallel=parallel) + parsed_email = parse_report_email( + msg_content, + ip_db_path=ip_db_path, + offline=offline, + nameservers=nameservers, + dns_timeout=dns_timeout, + strip_attachment_payloads=sa, + parallel=parallel, + ) if parsed_email["report_type"] == "aggregate": aggregate_reports.append(cast(OrderedDict[Any, Any], parsed_email["report"])) elif parsed_email["report_type"] == "forensic": @@ -1078,8 +1119,9 @@ def get_dmarc_reports_from_mbox( logger.warning(error.__str__()) except mailbox.NoSuchMailboxError: raise InvalidDMARCReport("Mailbox {0} does not 
exist".format(input_)) - return OrderedDict([("aggregate_reports", aggregate_reports), - ("forensic_reports", forensic_reports)]) + return OrderedDict( + [("aggregate_reports", aggregate_reports), ("forensic_reports", forensic_reports)] + ) def get_dmarc_reports_from_mailbox( @@ -1097,8 +1139,7 @@ def get_dmarc_reports_from_mailbox( batch_size: int = 10, create_folders: bool = True, ) -> OrderedDict[str, List[OrderedDict]]: - """ - Fetches and parses DMARC reports from a mailbox + """Fetches and parses DMARC reports from a mailbox Args: connection: A Mailbox connection object @@ -1110,13 +1151,10 @@ def get_dmarc_reports_from_mailbox( offline: Do not query online for geolocation or DNS nameservers: A list of DNS nameservers to query dns_timeout: Set the DNS query timeout - strip_attachment_payloads: Remove attachment payloads from - forensic report results + strip_attachment_payloads: Remove attachment payloads from forensic report results results: Results from the previous run - batch_size: Number of messages to read and process before saving - (use 0 for no limit) - create_folders: Whether to create the destination folders - (not used in watch) + batch_size: Number of messages to read and process before saving (use 0 for no limit) + create_folders: Whether to create the destination folders (not used in watch) Returns: OrderedDict: Lists of ``aggregate_reports`` and ``forensic_reports`` @@ -1147,8 +1185,7 @@ def get_dmarc_reports_from_mailbox( messages = connection.fetch_messages(reports_folder, batch_size=batch_size) total_messages = len(messages) - logger.debug("Found {0} messages in {1}".format(len(messages), - reports_folder)) + logger.debug("Found {0} messages in {1}".format(len(messages), reports_folder)) if batch_size: message_limit = min(total_messages, batch_size) @@ -1159,19 +1196,19 @@ def get_dmarc_reports_from_mailbox( for i in range(message_limit): msg_uid = messages[i] - logger.debug("Processing message {0} of {1}: UID {2}".format( - i+1, message_limit, msg_uid - )) + logger.debug("Processing message {0} of {1}: UID {2}".format(i + 1, message_limit, msg_uid)) msg_content = connection.fetch_message(msg_uid) try: sa = strip_attachment_payloads - parsed_email = parse_report_email(msg_content, - nameservers=nameservers, - dns_timeout=dns_timeout, - ip_db_path=ip_db_path, - offline=offline, - strip_attachment_payloads=sa, - keep_alive=connection.keepalive) + parsed_email = parse_report_email( + msg_content, + nameservers=nameservers, + dns_timeout=dns_timeout, + ip_db_path=ip_db_path, + offline=offline, + strip_attachment_payloads=sa, + keep_alive=connection.keepalive, + ) if parsed_email["report_type"] == "aggregate": aggregate_reports.append(cast(OrderedDict[Any, Any], parsed_email["report"])) aggregate_report_msg_uids.append(msg_uid) @@ -1182,73 +1219,79 @@ def get_dmarc_reports_from_mailbox( logger.warning(error.__str__()) if not test: if delete: - logger.debug( - "Deleting message UID {0}".format(msg_uid)) + logger.debug("Deleting message UID {0}".format(msg_uid)) connection.delete_message(msg_uid) else: logger.debug( - "Moving message UID {0} to {1}".format( - msg_uid, invalid_reports_folder)) + "Moving message UID {0} to {1}".format(msg_uid, invalid_reports_folder) + ) connection.move_message(msg_uid, invalid_reports_folder) if not test: if delete: - processed_messages = aggregate_report_msg_uids + \ - forensic_report_msg_uids + processed_messages = aggregate_report_msg_uids + forensic_report_msg_uids number_of_processed_msgs = len(processed_messages) for i in 
range(number_of_processed_msgs): msg_uid = processed_messages[i] logger.debug( "Deleting message {0} of {1}: UID {2}".format( - i + 1, number_of_processed_msgs, msg_uid)) + i + 1, number_of_processed_msgs, msg_uid + ) + ) try: connection.delete_message(msg_uid) except Exception as e: - message = "Mailbox error: Error deleting message UID {0}: {1}".format(msg_uid, repr(e)) + message = "Mailbox error: Error deleting message UID {0}: {1}".format( + msg_uid, repr(e) + ) logger.error(message) else: if len(aggregate_report_msg_uids) > 0: log_message = "Moving aggregate report messages from" logger.debug( - "{0} {1} to {2}".format( - log_message, reports_folder, - aggregate_reports_folder)) + "{0} {1} to {2}".format(log_message, reports_folder, aggregate_reports_folder) + ) number_of_agg_report_msgs = len(aggregate_report_msg_uids) for i in range(number_of_agg_report_msgs): msg_uid = aggregate_report_msg_uids[i] logger.debug( "Moving message {0} of {1}: UID {2}".format( - i+1, number_of_agg_report_msgs, msg_uid)) + i + 1, number_of_agg_report_msgs, msg_uid + ) + ) try: - connection.move_message(msg_uid, - aggregate_reports_folder) + connection.move_message(msg_uid, aggregate_reports_folder) except Exception as e: - message = "Mailbox error: Error moving message UID {0}: {1}".format(msg_uid, repr(e)) + message = "Mailbox error: Error moving message UID {0}: {1}".format( + msg_uid, repr(e) + ) logger.error(message) if len(forensic_report_msg_uids) > 0: message = "Moving forensic report messages from" logger.debug( - "{0} {1} to {2}".format(message, - reports_folder, - forensic_reports_folder)) + "{0} {1} to {2}".format(message, reports_folder, forensic_reports_folder) + ) number_of_forensic_msgs = len(forensic_report_msg_uids) for i in range(number_of_forensic_msgs): msg_uid = forensic_report_msg_uids[i] message = "Moving message" - logger.debug("{0} {1} of {2}: UID {3}".format( - message, - i + 1, number_of_forensic_msgs, msg_uid)) + logger.debug( + "{0} {1} of {2}: UID {3}".format( + message, i + 1, number_of_forensic_msgs, msg_uid + ) + ) try: - connection.move_message(msg_uid, - forensic_reports_folder) + connection.move_message(msg_uid, forensic_reports_folder) except Exception as e: message = "Mailbox error: Error moving message UID {0}: {1}".format( - msg_uid, repr(e)) + msg_uid, repr(e) + ) logger.error(message) - results = OrderedDict([("aggregate_reports", aggregate_reports), - ("forensic_reports", forensic_reports)]) + results = OrderedDict( + [("aggregate_reports", aggregate_reports), ("forensic_reports", forensic_reports)] + ) total_messages = len(connection.fetch_messages(reports_folder)) @@ -1265,7 +1308,7 @@ def get_dmarc_reports_from_mailbox( strip_attachment_payloads=strip_attachment_payloads, results=results, ip_db_path=ip_db_path, - offline=offline + offline=offline, ) return results @@ -1284,11 +1327,9 @@ def watch_inbox( nameservers: Optional[List[str]] = None, dns_timeout: float = 6.0, strip_attachment_payloads: bool = False, - batch_size: Optional[int] = None + batch_size: Optional[int] = None, ) -> None: - """ - Watches the mailbox for new messages and - sends the results to a callback function + """Watches a mailbox for new messages and sends the results to a callback function Args: mailbox_connection: The mailbox connection object @@ -1297,36 +1338,34 @@ def watch_inbox( archive_folder: The folder to move processed mail to delete: Delete messages after processing them test: Do not move or delete messages after processing them - check_timeout: Number of seconds to wait for 
a IMAP IDLE response - or the number of seconds until the next mail check + check_timeout: Number of seconds to wait for a IMAP IDLE response or the number of seconds until the next mail check ip_db_path: Path to a MMDB file from MaxMind or DBIP offline: Do not query online for geolocation or DNS - nameservers: A list of one or more nameservers to use - (Cloudflare's public DNS resolvers by default) + nameservers: A list of one or more nameservers to use (Cloudflare's public DNS resolvers by default) dns_timeout: Set the DNS query timeout - strip_attachment_payloads: Replace attachment payloads in - forensic report samples with None + strip_attachment_payloads: Replace attachment payloads in forensic report samples with None batch_size: Number of messages to read and process before saving """ def check_callback(connection): sa = strip_attachment_payloads - res = get_dmarc_reports_from_mailbox(connection=connection, - reports_folder=reports_folder, - archive_folder=archive_folder, - delete=delete, - test=test, - ip_db_path=ip_db_path, - offline=offline, - nameservers=nameservers, - dns_timeout=dns_timeout, - strip_attachment_payloads=sa, - batch_size=batch_size, - create_folders=False) + res = get_dmarc_reports_from_mailbox( + connection=connection, + reports_folder=reports_folder, + archive_folder=archive_folder, + delete=delete, + test=test, + ip_db_path=ip_db_path, + offline=offline, + nameservers=nameservers, + dns_timeout=dns_timeout, + strip_attachment_payloads=sa, + batch_size=batch_size, + create_folders=False, + ) callback(res) - mailbox_connection.watch(check_callback=check_callback, - check_timeout=check_timeout) + mailbox_connection.watch(check_callback=check_callback, check_timeout=check_timeout) def append_json(filename: str, reports: List[OrderedDict]) -> None: @@ -1372,10 +1411,9 @@ def save_output( aggregate_json_filename: str = "aggregate.json", forensic_json_filename: str = "forensic.json", aggregate_csv_filename: str = "aggregate.csv", - forensic_csv_filename: str = "forensic.csv" + forensic_csv_filename: str = "forensic.csv", ) -> None: - """ - Save report data in the given directory + """Save report data in the given directory Args: results: Parsing results @@ -1395,17 +1433,19 @@ def save_output( else: os.makedirs(output_directory) - append_json(os.path.join(output_directory, aggregate_json_filename), - aggregate_reports) + append_json(os.path.join(output_directory, aggregate_json_filename), aggregate_reports) - append_csv(os.path.join(output_directory, aggregate_csv_filename), - parsed_aggregate_reports_to_csv(aggregate_reports)) + append_csv( + os.path.join(output_directory, aggregate_csv_filename), + parsed_aggregate_reports_to_csv(aggregate_reports), + ) - append_json(os.path.join(output_directory, forensic_json_filename), - forensic_reports) + append_json(os.path.join(output_directory, forensic_json_filename), forensic_reports) - append_csv(os.path.join(output_directory, forensic_csv_filename), - parsed_forensic_reports_to_csv(forensic_reports)) + append_csv( + os.path.join(output_directory, forensic_csv_filename), + parsed_forensic_reports_to_csv(forensic_reports), + ) samples_directory = os.path.join(output_directory, "samples") if not os.path.exists(samples_directory): @@ -1433,8 +1473,7 @@ def save_output( def get_report_zip(results: OrderedDict[str, List[OrderedDict]]) -> bytes: - """ - Creates a zip file of parsed report output + """Creates a zip file of parsed report output Args: results: The parsed results @@ -1442,6 +1481,7 @@ def get_report_zip(results: 
OrderedDict[str, List[OrderedDict]]) -> bytes: Returns: raw zip file """ + def add_subdir(root_path, subdir): subdir_path = os.path.join(root_path, subdir) for subdir_root, subdir_dirs, subdir_files in os.walk(subdir_path): @@ -1458,13 +1498,12 @@ def add_subdir(root_path, subdir): tmp_dir = tempfile.mkdtemp() try: save_output(results, tmp_dir) - with zipfile.ZipFile(storage, 'w', zipfile.ZIP_DEFLATED) as zip_file: + with zipfile.ZipFile(storage, "w", zipfile.ZIP_DEFLATED) as zip_file: for root, dirs, files in os.walk(tmp_dir): for file in files: file_path = os.path.join(root, file) if os.path.isfile(file_path): - arcname = os.path.join(os.path.relpath(root, tmp_dir), - file) + arcname = os.path.join(os.path.relpath(root, tmp_dir), file) zip_file.write(file_path, arcname) for directory in dirs: dir_path = os.path.join(root, directory) @@ -1491,10 +1530,9 @@ def email_results( password: Optional[str] = None, subject: Optional[str] = None, attachment_filename: Optional[str] = None, - message: Optional[str] = None + message: Optional[str] = None, ) -> None: - """ - Emails parsing results as a zip file + """Emails parsing results as a zip file Args: results: Parsing results @@ -1530,9 +1568,19 @@ def email_results( zip_bytes = get_report_zip(results) attachments = [(filename, zip_bytes)] - send_email(host, mail_from, mail_to, message_cc=mail_cc, - message_bcc=mail_bcc, port=port, - require_encryption=require_encryption, verify=verify, - username=username, password=password, subject=subject, - attachments=attachments, plain_message=message) + send_email( + host, + mail_from, + mail_to, + message_cc=mail_cc, + message_bcc=mail_bcc, + port=port, + require_encryption=require_encryption, + verify=verify, + username=username, + password=password, + subject=subject, + attachments=attachments, + plain_message=message, + ) return None diff --git a/parsedmarc/cli.py b/parsedmarc/cli.py index 445b3e40..f07201ae 100644 --- a/parsedmarc/cli.py +++ b/parsedmarc/cli.py @@ -17,10 +17,23 @@ import time from tqdm import tqdm -from parsedmarc import get_dmarc_reports_from_mailbox, watch_inbox, \ - parse_report_file, get_dmarc_reports_from_mbox, elastic, kafkaclient, \ - splunk, save_output, email_results, ParserError, __version__, \ - InvalidDMARCReport, s3, syslog, loganalytics +from parsedmarc import ( + get_dmarc_reports_from_mailbox, + watch_inbox, + parse_report_file, + get_dmarc_reports_from_mbox, + elastic, + kafkaclient, + splunk, + save_output, + email_results, + ParserError, + __version__, + InvalidDMARCReport, + s3, + syslog, + loganalytics, +) from parsedmarc.mail import IMAPConnection, MSGraphConnection, GmailConnection from parsedmarc.mail.graph import AuthMethod @@ -29,8 +42,8 @@ from parsedmarc.utils import is_mbox formatter = logging.Formatter( - fmt='%(levelname)8s:%(filename)s:%(lineno)d:%(message)s', - datefmt='%Y-%m-%d:%H:%M:%S') + fmt="%(levelname)8s:%(filename)s:%(lineno)d:%(message)s", datefmt="%Y-%m-%d:%H:%M:%S" +) handler = logging.StreamHandler() handler.setFormatter(formatter) logger.addHandler(handler) @@ -42,17 +55,18 @@ def _str_to_list(s): return list(map(lambda i: i.lstrip(), _list)) -def cli_parse(file_path, sa, nameservers, dns_timeout, - ip_db_path, offline, parallel=False): +def cli_parse(file_path, sa, nameservers, dns_timeout, ip_db_path, offline, parallel=False): """Separated this function for multiprocessing""" try: - file_results = parse_report_file(file_path, - ip_db_path=ip_db_path, - offline=offline, - nameservers=nameservers, - dns_timeout=dns_timeout, - 
strip_attachment_payloads=sa, - parallel=parallel) + file_results = parse_report_file( + file_path, + ip_db_path=ip_db_path, + offline=offline, + nameservers=nameservers, + dns_timeout=dns_timeout, + strip_attachment_payloads=sa, + parallel=parallel, + ) except ParserError as error: return error, file_path finally: @@ -71,18 +85,19 @@ def _main(): """Called when the module is executed""" def process_reports(reports_): - output_str = "{0}\n".format(json.dumps(reports_, - ensure_ascii=False, - indent=2)) + output_str = "{0}\n".format(json.dumps(reports_, ensure_ascii=False, indent=2)) if not opts.silent: print(output_str) if opts.output: - save_output(results, output_directory=opts.output, - aggregate_json_filename=opts.aggregate_json_filename, - forensic_json_filename=opts.forensic_json_filename, - aggregate_csv_filename=opts.aggregate_csv_filename, - forensic_csv_filename=opts.forensic_csv_filename) + save_output( + results, + output_directory=opts.output, + aggregate_json_filename=opts.aggregate_json_filename, + forensic_json_filename=opts.forensic_json_filename, + aggregate_csv_filename=opts.aggregate_csv_filename, + forensic_csv_filename=opts.forensic_csv_filename, + ) if opts.save_aggregate: for report in reports_["aggregate_reports"]: try: @@ -94,23 +109,19 @@ def process_reports(reports_): index_suffix=opts.elasticsearch_index_suffix, monthly_indexes=opts.elasticsearch_monthly_indexes, number_of_shards=shards, - number_of_replicas=replicas + number_of_replicas=replicas, ) except elastic.AlreadySaved as warning: logger.warning(warning.__str__()) except elastic.ElasticsearchError as error_: - logger.error("Elasticsearch Error: {0}".format( - error_.__str__())) + logger.error("Elasticsearch Error: {0}".format(error_.__str__())) except Exception as error_: - logger.error("Elasticsearch exception error: {}".format( - error_.__str__())) + logger.error("Elasticsearch exception error: {}".format(error_.__str__())) try: if opts.kafka_hosts: - kafka_client.save_aggregate_reports_to_kafka( - report, kafka_aggregate_topic) + kafka_client.save_aggregate_reports_to_kafka(report, kafka_aggregate_topic) except Exception as error_: - logger.error("Kafka Error: {0}".format( - error_.__str__())) + logger.error("Kafka Error: {0}".format(error_.__str__())) try: if opts.s3_bucket: s3_client.save_aggregate_report_to_s3(report) @@ -125,8 +136,7 @@ def process_reports(reports_): try: aggregate_reports_ = reports_["aggregate_reports"] if len(aggregate_reports_) > 0: - hec_client.save_aggregate_reports_to_splunk( - aggregate_reports_) + hec_client.save_aggregate_reports_to_splunk(aggregate_reports_) except splunk.SplunkError as e: logger.error("Splunk HEC error: {0}".format(e.__str__())) if opts.save_forensic: @@ -140,21 +150,19 @@ def process_reports(reports_): index_suffix=opts.elasticsearch_index_suffix, monthly_indexes=opts.elasticsearch_monthly_indexes, number_of_shards=shards, - number_of_replicas=replicas) + number_of_replicas=replicas, + ) except elastic.AlreadySaved as warning: logger.warning(warning.__str__()) except elastic.ElasticsearchError as error_: - logger.error("Elasticsearch Error: {0}".format( - error_.__str__())) + logger.error("Elasticsearch Error: {0}".format(error_.__str__())) except InvalidDMARCReport as error_: logger.error(error_.__str__()) try: if opts.kafka_hosts: - kafka_client.save_forensic_reports_to_kafka( - report, kafka_forensic_topic) + kafka_client.save_forensic_reports_to_kafka(report, kafka_forensic_topic) except Exception as error_: - logger.error("Kafka Error: 
{0}".format( - error_.__str__())) + logger.error("Kafka Error: {0}".format(error_.__str__())) try: if opts.s3_bucket: s3_client.save_forensic_report_to_s3(report) @@ -169,8 +177,7 @@ def process_reports(reports_): try: forensic_reports_ = reports_["forensic_reports"] if len(forensic_reports_) > 0: - hec_client.save_forensic_reports_to_splunk( - forensic_reports_) + hec_client.save_forensic_reports_to_splunk(forensic_reports_) except splunk.SplunkError as e: logger.error("Splunk HEC error: {0}".format(e.__str__())) if opts.la_dce: @@ -182,173 +189,180 @@ def process_reports(reports_): dce=opts.la_dce, dcr_immutable_id=opts.la_dcr_immutable_id, dcr_aggregate_stream=opts.la_dcr_aggregate_stream, - dcr_forensic_stream=opts.la_dcr_forensic_stream + dcr_forensic_stream=opts.la_dcr_forensic_stream, ) - la_client.publish_results( - reports_, - opts.save_aggregate, - opts.save_forensic) + la_client.publish_results(reports_, opts.save_aggregate, opts.save_forensic) except loganalytics.LogAnalyticsException as e: logger.error("Log Analytics error: {0}".format(e.__str__())) except Exception as e: logger.error( - "Unknown error occured" + - " during the publishing" + - " to Log Analitics: " + - e.__str__()) + "Unknown error occured" + + " during the publishing" + + " to Log Analitics: " + + e.__str__() + ) arg_parser = ArgumentParser(description="Parses DMARC reports") - arg_parser.add_argument("-c", "--config-file", - help="a path to a configuration file " - "(--silent implied)") - arg_parser.add_argument("file_path", nargs="*", - help="one or more paths to aggregate or forensic " - "report files, emails, or mbox files'") - strip_attachment_help = "remove attachment payloads from forensic " \ - "report output" - arg_parser.add_argument("--strip-attachment-payloads", - help=strip_attachment_help, action="store_true") - arg_parser.add_argument("-o", "--output", - help="write output files to the given directory") - arg_parser.add_argument("--aggregate-json-filename", - help="filename for the aggregate JSON output file", - default="aggregate.json") - arg_parser.add_argument("--forensic-json-filename", - help="filename for the forensic JSON output file", - default="forensic.json") - arg_parser.add_argument("--aggregate-csv-filename", - help="filename for the aggregate CSV output file", - default="aggregate.csv") - arg_parser.add_argument("--forensic-csv-filename", - help="filename for the forensic CSV output file", - default="forensic.csv") - arg_parser.add_argument("-n", "--nameservers", nargs="+", - help="nameservers to query") - arg_parser.add_argument("-t", "--dns_timeout", - help="number of seconds to wait for an answer " - "from DNS (default: 2.0)", - type=float, - default=2.0) - arg_parser.add_argument("--offline", action="store_true", - help="do not make online queries for geolocation " - " or DNS") - arg_parser.add_argument("-s", "--silent", action="store_true", - help="only print errors") - arg_parser.add_argument("-w", "--warnings", action="store_true", - help="print warnings in addition to errors") - arg_parser.add_argument("--verbose", action="store_true", - help="more verbose output") - arg_parser.add_argument("--debug", action="store_true", - help="print debugging information") - arg_parser.add_argument("--log-file", default=None, - help="output logging to a file") - arg_parser.add_argument("-v", "--version", action="version", - version=__version__) + arg_parser.add_argument( + "-c", "--config-file", help="a path to a configuration file " "(--silent implied)" + ) + arg_parser.add_argument( + 
"file_path", + nargs="*", + help="one or more paths to aggregate or forensic " "report files, emails, or mbox files'", + ) + strip_attachment_help = "remove attachment payloads from forensic " "report output" + arg_parser.add_argument( + "--strip-attachment-payloads", help=strip_attachment_help, action="store_true" + ) + arg_parser.add_argument("-o", "--output", help="write output files to the given directory") + arg_parser.add_argument( + "--aggregate-json-filename", + help="filename for the aggregate JSON output file", + default="aggregate.json", + ) + arg_parser.add_argument( + "--forensic-json-filename", + help="filename for the forensic JSON output file", + default="forensic.json", + ) + arg_parser.add_argument( + "--aggregate-csv-filename", + help="filename for the aggregate CSV output file", + default="aggregate.csv", + ) + arg_parser.add_argument( + "--forensic-csv-filename", + help="filename for the forensic CSV output file", + default="forensic.csv", + ) + arg_parser.add_argument("-n", "--nameservers", nargs="+", help="nameservers to query") + arg_parser.add_argument( + "-t", + "--dns_timeout", + help="number of seconds to wait for an answer " "from DNS (default: 2.0)", + type=float, + default=2.0, + ) + arg_parser.add_argument( + "--offline", + action="store_true", + help="do not make online queries for geolocation " " or DNS", + ) + arg_parser.add_argument("-s", "--silent", action="store_true", help="only print errors") + arg_parser.add_argument( + "-w", "--warnings", action="store_true", help="print warnings in addition to errors" + ) + arg_parser.add_argument("--verbose", action="store_true", help="more verbose output") + arg_parser.add_argument("--debug", action="store_true", help="print debugging information") + arg_parser.add_argument("--log-file", default=None, help="output logging to a file") + arg_parser.add_argument("-v", "--version", action="version", version=__version__) aggregate_reports = [] forensic_reports = [] args = arg_parser.parse_args() - default_gmail_api_scope = 'https://www.googleapis.com/auth/gmail.modify' - - opts = Namespace(file_path=args.file_path, - config_file=args.config_file, - offline=args.offline, - strip_attachment_payloads=args.strip_attachment_payloads, - output=args.output, - aggregate_csv_filename=args.aggregate_csv_filename, - aggregate_json_filename=args.aggregate_json_filename, - forensic_csv_filename=args.forensic_csv_filename, - forensic_json_filename=args.forensic_json_filename, - nameservers=args.nameservers, - silent=args.silent, - warnings=args.warnings, - dns_timeout=args.dns_timeout, - debug=args.debug, - verbose=args.verbose, - save_aggregate=False, - save_forensic=False, - mailbox_reports_folder="INBOX", - mailbox_archive_folder="Archive", - mailbox_watch=False, - mailbox_delete=False, - mailbox_test=False, - mailbox_batch_size=None, - mailbox_check_timeout=30, - imap_host=None, - imap_skip_certificate_verification=False, - imap_ssl=True, - imap_port=993, - imap_timeout=30, - imap_max_retries=4, - imap_user=None, - imap_password=None, - graph_auth_method=None, - graph_user=None, - graph_password=None, - graph_client_id=None, - graph_client_secret=None, - graph_tenant_id=None, - graph_mailbox=None, - graph_allow_unencrypted_storage=False, - hec=None, - hec_token=None, - hec_index=None, - hec_skip_certificate_verification=False, - elasticsearch_hosts=None, - elasticsearch_timeout=60, - elasticsearch_number_of_shards=1, - elasticsearch_number_of_replicas=0, - elasticsearch_index_suffix=None, - elasticsearch_ssl=True, - 
elasticsearch_ssl_cert_path=None, - elasticsearch_monthly_indexes=False, - elasticsearch_username=None, - elasticsearch_password=None, - elasticsearch_apiKey=None, - kafka_hosts=None, - kafka_username=None, - kafka_password=None, - kafka_aggregate_topic=None, - kafka_forensic_topic=None, - kafka_ssl=False, - kafka_skip_certificate_verification=False, - smtp_host=None, - smtp_port=25, - smtp_ssl=False, - smtp_skip_certificate_verification=False, - smtp_user=None, - smtp_password=None, - smtp_from=None, - smtp_to=[], - smtp_subject="parsedmarc report", - smtp_message="Please see the attached DMARC results.", - s3_bucket=None, - s3_path=None, - s3_region_name=None, - s3_endpoint_url=None, - s3_access_key_id=None, - s3_secret_access_key=None, - syslog_server=None, - syslog_port=None, - gmail_api_credentials_file=None, - gmail_api_token_file=None, - gmail_api_include_spam_trash=False, - gmail_api_scopes=[], - gmail_api_oauth2_port=8080, - log_file=args.log_file, - n_procs=1, - chunk_size=1, - ip_db_path=None, - la_client_id=None, - la_client_secret=None, - la_tenant_id=None, - la_dce=None, - la_dcr_immutable_id=None, - la_dcr_aggregate_stream=None, - la_dcr_forensic_stream=None - ) + default_gmail_api_scope = "https://www.googleapis.com/auth/gmail.modify" + + opts = Namespace( + file_path=args.file_path, + config_file=args.config_file, + offline=args.offline, + strip_attachment_payloads=args.strip_attachment_payloads, + output=args.output, + aggregate_csv_filename=args.aggregate_csv_filename, + aggregate_json_filename=args.aggregate_json_filename, + forensic_csv_filename=args.forensic_csv_filename, + forensic_json_filename=args.forensic_json_filename, + nameservers=args.nameservers, + silent=args.silent, + warnings=args.warnings, + dns_timeout=args.dns_timeout, + debug=args.debug, + verbose=args.verbose, + save_aggregate=False, + save_forensic=False, + mailbox_reports_folder="INBOX", + mailbox_archive_folder="Archive", + mailbox_watch=False, + mailbox_delete=False, + mailbox_test=False, + mailbox_batch_size=None, + mailbox_check_timeout=30, + imap_host=None, + imap_skip_certificate_verification=False, + imap_ssl=True, + imap_port=993, + imap_timeout=30, + imap_max_retries=4, + imap_user=None, + imap_password=None, + graph_auth_method=None, + graph_user=None, + graph_password=None, + graph_client_id=None, + graph_client_secret=None, + graph_tenant_id=None, + graph_mailbox=None, + graph_allow_unencrypted_storage=False, + hec=None, + hec_token=None, + hec_index=None, + hec_skip_certificate_verification=False, + elasticsearch_hosts=None, + elasticsearch_timeout=60, + elasticsearch_number_of_shards=1, + elasticsearch_number_of_replicas=0, + elasticsearch_index_suffix=None, + elasticsearch_ssl=True, + elasticsearch_ssl_cert_path=None, + elasticsearch_monthly_indexes=False, + elasticsearch_username=None, + elasticsearch_password=None, + elasticsearch_apiKey=None, + kafka_hosts=None, + kafka_username=None, + kafka_password=None, + kafka_aggregate_topic=None, + kafka_forensic_topic=None, + kafka_ssl=False, + kafka_skip_certificate_verification=False, + smtp_host=None, + smtp_port=25, + smtp_ssl=False, + smtp_skip_certificate_verification=False, + smtp_user=None, + smtp_password=None, + smtp_from=None, + smtp_to=[], + smtp_subject="parsedmarc report", + smtp_message="Please see the attached DMARC results.", + s3_bucket=None, + s3_path=None, + s3_region_name=None, + s3_endpoint_url=None, + s3_access_key_id=None, + s3_secret_access_key=None, + syslog_server=None, + syslog_port=None, + 
gmail_api_credentials_file=None, + gmail_api_token_file=None, + gmail_api_include_spam_trash=False, + gmail_api_scopes=[], + gmail_api_oauth2_port=8080, + log_file=args.log_file, + n_procs=1, + chunk_size=1, + ip_db_path=None, + la_client_id=None, + la_client_secret=None, + la_tenant_id=None, + la_dce=None, + la_dcr_immutable_id=None, + la_dcr_aggregate_stream=None, + la_dcr_forensic_stream=None, + ) args = arg_parser.parse_args() if args.config_file: @@ -365,21 +379,18 @@ def process_reports(reports_): opts.offline = general_config.getboolean("offline") if "strip_attachment_payloads" in general_config: opts.strip_attachment_payloads = general_config.getboolean( - "strip_attachment_payloads") + "strip_attachment_payloads" + ) if "output" in general_config: opts.output = general_config["output"] if "aggregate_json_filename" in general_config: - opts.aggregate_json_filename = general_config[ - "aggregate_json_filename"] + opts.aggregate_json_filename = general_config["aggregate_json_filename"] if "forensic_json_filename" in general_config: - opts.forensic_json_filename = general_config[ - "forensic_json_filename"] + opts.forensic_json_filename = general_config["forensic_json_filename"] if "aggregate_csv_filename" in general_config: - opts.aggregate_csv_filename = general_config[ - "aggregate_csv_filename"] + opts.aggregate_csv_filename = general_config["aggregate_csv_filename"] if "forensic_csv_filename" in general_config: - opts.forensic_csv_filename = general_config[ - "forensic_csv_filename"] + opts.forensic_csv_filename = general_config["forensic_csv_filename"] if "nameservers" in general_config: opts.nameservers = _str_to_list(general_config["nameservers"]) if "dns_timeout" in general_config: @@ -424,20 +435,20 @@ def process_reports(reports_): if "batch_size" in mailbox_config: opts.mailbox_batch_size = mailbox_config.getint("batch_size") if "check_timeout" in mailbox_config: - opts.mailbox_check_timeout = mailbox_config.getint( - "check_timeout") + opts.mailbox_check_timeout = mailbox_config.getint("check_timeout") if "imap" in config.sections(): imap_config = config["imap"] if "watch" in imap_config: - logger.warning("Starting in 8.0.0, the watch option has been " - "moved from the imap configuration section to " - "the mailbox configuration section.") + logger.warning( + "Starting in 8.0.0, the watch option has been " + "moved from the imap configuration section to " + "the mailbox configuration section." 
+ ) if "host" in imap_config: opts.imap_host = imap_config["host"] else: - logger.error("host setting missing from the " - "imap config section") + logger.error("host setting missing from the " "imap config section") exit(-1) if "port" in imap_config: opts.imap_port = imap_config.getint("port") @@ -448,65 +459,76 @@ def process_reports(reports_): if "ssl" in imap_config: opts.imap_ssl = imap_config.getboolean("ssl") if "skip_certificate_verification" in imap_config: - imap_verify = imap_config.getboolean( - "skip_certificate_verification") + imap_verify = imap_config.getboolean("skip_certificate_verification") opts.imap_skip_certificate_verification = imap_verify if "user" in imap_config: opts.imap_user = imap_config["user"] else: - logger.critical("user setting missing from the " - "imap config section") + logger.critical("user setting missing from the " "imap config section") exit(-1) if "password" in imap_config: opts.imap_password = imap_config["password"] else: - logger.critical("password setting missing from the " - "imap config section") + logger.critical("password setting missing from the " "imap config section") exit(-1) if "reports_folder" in imap_config: opts.mailbox_reports_folder = imap_config["reports_folder"] - logger.warning("Use of the reports_folder option in the imap " - "configuration section has been deprecated. " - "Use this option in the mailbox configuration " - "section instead.") + logger.warning( + "Use of the reports_folder option in the imap " + "configuration section has been deprecated. " + "Use this option in the mailbox configuration " + "section instead." + ) if "archive_folder" in imap_config: opts.mailbox_archive_folder = imap_config["archive_folder"] - logger.warning("Use of the archive_folder option in the imap " - "configuration section has been deprecated. " - "Use this option in the mailbox configuration " - "section instead.") + logger.warning( + "Use of the archive_folder option in the imap " + "configuration section has been deprecated. " + "Use this option in the mailbox configuration " + "section instead." + ) if "watch" in imap_config: opts.mailbox_watch = imap_config.getboolean("watch") - logger.warning("Use of the watch option in the imap " - "configuration section has been deprecated. " - "Use this option in the mailbox configuration " - "section instead.") + logger.warning( + "Use of the watch option in the imap " + "configuration section has been deprecated. " + "Use this option in the mailbox configuration " + "section instead." + ) if "delete" in imap_config: - logger.warning("Use of the delete option in the imap " - "configuration section has been deprecated. " - "Use this option in the mailbox configuration " - "section instead.") + logger.warning( + "Use of the delete option in the imap " + "configuration section has been deprecated. " + "Use this option in the mailbox configuration " + "section instead." + ) if "test" in imap_config: opts.mailbox_test = imap_config.getboolean("test") - logger.warning("Use of the test option in the imap " - "configuration section has been deprecated. " - "Use this option in the mailbox configuration " - "section instead.") + logger.warning( + "Use of the test option in the imap " + "configuration section has been deprecated. " + "Use this option in the mailbox configuration " + "section instead." + ) if "batch_size" in imap_config: opts.mailbox_batch_size = imap_config.getint("batch_size") - logger.warning("Use of the batch_size option in the imap " - "configuration section has been deprecated. 
" - "Use this option in the mailbox configuration " - "section instead.") + logger.warning( + "Use of the batch_size option in the imap " + "configuration section has been deprecated. " + "Use this option in the mailbox configuration " + "section instead." + ) if "msgraph" in config.sections(): graph_config = config["msgraph"] opts.graph_token_file = graph_config.get("token_file", ".token") if "auth_method" not in graph_config: - logger.info("auth_method setting missing from the " - "msgraph config section " - "defaulting to UsernamePassword") + logger.info( + "auth_method setting missing from the " + "msgraph config section " + "defaulting to UsernamePassword" + ) opts.graph_auth_method = AuthMethod.UsernamePassword.name else: opts.graph_auth_method = graph_config["auth_method"] @@ -515,188 +537,159 @@ def process_reports(reports_): if "user" in graph_config: opts.graph_user = graph_config["user"] else: - logger.critical("user setting missing from the " - "msgraph config section") + logger.critical("user setting missing from the " "msgraph config section") exit(-1) if "password" in graph_config: opts.graph_password = graph_config["password"] else: - logger.critical("password setting missing from the " - "msgraph config section") + logger.critical("password setting missing from the " "msgraph config section") exit(-1) if opts.graph_auth_method != AuthMethod.UsernamePassword.name: if "tenant_id" in graph_config: - opts.graph_tenant_id = graph_config['tenant_id'] + opts.graph_tenant_id = graph_config["tenant_id"] else: - logger.critical("tenant_id setting missing from the " - "msgraph config section") + logger.critical("tenant_id setting missing from the " "msgraph config section") exit(-1) if "client_secret" in graph_config: opts.graph_client_secret = graph_config["client_secret"] else: - logger.critical("client_secret setting missing from the " - "msgraph config section") + logger.critical("client_secret setting missing from the " "msgraph config section") exit(-1) if "client_id" in graph_config: opts.graph_client_id = graph_config["client_id"] else: - logger.critical("client_id setting missing from the " - "msgraph config section") + logger.critical("client_id setting missing from the " "msgraph config section") exit(-1) if "mailbox" in graph_config: opts.graph_mailbox = graph_config["mailbox"] elif opts.graph_auth_method != AuthMethod.UsernamePassword.name: - logger.critical("mailbox setting missing from the " - "msgraph config section") + logger.critical("mailbox setting missing from the " "msgraph config section") exit(-1) if "allow_unencrypted_storage" in graph_config: opts.graph_allow_unencrypted_storage = graph_config.getboolean( - "allow_unencrypted_storage") + "allow_unencrypted_storage" + ) if "elasticsearch" in config: elasticsearch_config = config["elasticsearch"] if "hosts" in elasticsearch_config: - opts.elasticsearch_hosts = _str_to_list(elasticsearch_config[ - "hosts"]) + opts.elasticsearch_hosts = _str_to_list(elasticsearch_config["hosts"]) else: - logger.critical("hosts setting missing from the " - "elasticsearch config section") + logger.critical("hosts setting missing from the " "elasticsearch config section") exit(-1) if "timeout" in elasticsearch_config: timeout = elasticsearch_config.getfloat("timeout") opts.elasticsearch_timeout = timeout if "number_of_shards" in elasticsearch_config: - number_of_shards = elasticsearch_config.getint( - "number_of_shards") + number_of_shards = elasticsearch_config.getint("number_of_shards") opts.elasticsearch_number_of_shards = 
number_of_shards if "number_of_replicas" in elasticsearch_config: - number_of_replicas = elasticsearch_config.getint( - "number_of_replicas") + number_of_replicas = elasticsearch_config.getint("number_of_replicas") opts.elasticsearch_number_of_replicas = number_of_replicas if "index_suffix" in elasticsearch_config: - opts.elasticsearch_index_suffix = elasticsearch_config[ - "index_suffix"] + opts.elasticsearch_index_suffix = elasticsearch_config["index_suffix"] if "monthly_indexes" in elasticsearch_config: monthly = elasticsearch_config.getboolean("monthly_indexes") opts.elasticsearch_monthly_indexes = monthly if "ssl" in elasticsearch_config: - opts.elasticsearch_ssl = elasticsearch_config.getboolean( - "ssl") + opts.elasticsearch_ssl = elasticsearch_config.getboolean("ssl") if "cert_path" in elasticsearch_config: - opts.elasticsearch_ssl_cert_path = elasticsearch_config[ - "cert_path"] + opts.elasticsearch_ssl_cert_path = elasticsearch_config["cert_path"] if "user" in elasticsearch_config: - opts.elasticsearch_username = elasticsearch_config[ - "user"] + opts.elasticsearch_username = elasticsearch_config["user"] if "password" in elasticsearch_config: - opts.elasticsearch_password = elasticsearch_config[ - "password"] + opts.elasticsearch_password = elasticsearch_config["password"] if "apiKey" in elasticsearch_config: - opts.elasticsearch_apiKey = elasticsearch_config[ - "apiKey"] + opts.elasticsearch_apiKey = elasticsearch_config["apiKey"] if "splunk_hec" in config.sections(): hec_config = config["splunk_hec"] if "url" in hec_config: opts.hec = hec_config["url"] else: - logger.critical("url setting missing from the " - "splunk_hec config section") + logger.critical("url setting missing from the " "splunk_hec config section") exit(-1) if "token" in hec_config: opts.hec_token = hec_config["token"] else: - logger.critical("token setting missing from the " - "splunk_hec config section") + logger.critical("token setting missing from the " "splunk_hec config section") exit(-1) if "index" in hec_config: opts.hec_index = hec_config["index"] else: - logger.critical("index setting missing from the " - "splunk_hec config section") + logger.critical("index setting missing from the " "splunk_hec config section") exit(-1) if "skip_certificate_verification" in hec_config: - opts.hec_skip_certificate_verification = hec_config[ - "skip_certificate_verification"] + opts.hec_skip_certificate_verification = hec_config["skip_certificate_verification"] if "kafka" in config.sections(): kafka_config = config["kafka"] if "hosts" in kafka_config: opts.kafka_hosts = _str_to_list(kafka_config["hosts"]) else: - logger.critical("hosts setting missing from the " - "kafka config section") + logger.critical("hosts setting missing from the " "kafka config section") exit(-1) if "user" in kafka_config: opts.kafka_username = kafka_config["user"] else: - logger.critical("user setting missing from the " - "kafka config section") + logger.critical("user setting missing from the " "kafka config section") exit(-1) if "password" in kafka_config: opts.kafka_password = kafka_config["password"] else: - logger.critical("password setting missing from the " - "kafka config section") + logger.critical("password setting missing from the " "kafka config section") exit(-1) if "ssl" in kafka_config: opts.kafka_ssl = kafka_config.getboolean("ssl") if "skip_certificate_verification" in kafka_config: - kafka_verify = kafka_config.getboolean( - "skip_certificate_verification") + kafka_verify = 
kafka_config.getboolean("skip_certificate_verification") opts.kafka_skip_certificate_verification = kafka_verify if "aggregate_topic" in kafka_config: opts.kafka_aggregate = kafka_config["aggregate_topic"] else: - logger.critical("aggregate_topic setting missing from the " - "kafka config section") + logger.critical("aggregate_topic setting missing from the " "kafka config section") exit(-1) if "forensic_topic" in kafka_config: opts.kafka_username = kafka_config["forensic_topic"] else: - logger.critical("forensic_topic setting missing from the " - "splunk_hec config section") + logger.critical( + "forensic_topic setting missing from the " "splunk_hec config section" + ) if "smtp" in config.sections(): smtp_config = config["smtp"] if "host" in smtp_config: opts.smtp_host = smtp_config["host"] else: - logger.critical("host setting missing from the " - "smtp config section") + logger.critical("host setting missing from the " "smtp config section") exit(-1) if "port" in smtp_config: opts.smtp_port = smtp_config.getint("port") if "ssl" in smtp_config: opts.smtp_ssl = smtp_config.getboolean("ssl") if "skip_certificate_verification" in smtp_config: - smtp_verify = smtp_config.getboolean( - "skip_certificate_verification") + smtp_verify = smtp_config.getboolean("skip_certificate_verification") opts.smtp_skip_certificate_verification = smtp_verify if "user" in smtp_config: opts.smtp_user = smtp_config["user"] else: - logger.critical("user setting missing from the " - "smtp config section") + logger.critical("user setting missing from the " "smtp config section") exit(-1) if "password" in smtp_config: opts.smtp_password = smtp_config["password"] else: - logger.critical("password setting missing from the " - "smtp config section") + logger.critical("password setting missing from the " "smtp config section") exit(-1) if "from" in smtp_config: opts.smtp_from = smtp_config["from"] else: - logger.critical("from setting missing from the " - "smtp config section") + logger.critical("from setting missing from the " "smtp config section") if "to" in smtp_config: opts.smtp_to = _str_to_list(smtp_config["to"]) else: - logger.critical("to setting missing from the " - "smtp config section") + logger.critical("to setting missing from the " "smtp config section") if "subject" in smtp_config: opts.smtp_subject = smtp_config["subject"] if "attachment" in smtp_config: @@ -708,8 +701,7 @@ def process_reports(reports_): if "bucket" in s3_config: opts.s3_bucket = s3_config["bucket"] else: - logger.critical("bucket setting missing from the " - "s3 config section") + logger.critical("bucket setting missing from the " "s3 config section") exit(-1) if "path" in s3_config: opts.s3_path = s3_config["path"] @@ -734,8 +726,7 @@ def process_reports(reports_): if "server" in syslog_config: opts.syslog_server = syslog_config["server"] else: - logger.critical("server setting missing from the " - "syslog config section") + logger.critical("server setting missing from the " "syslog config section") exit(-1) if "port" in syslog_config: opts.syslog_port = syslog_config["port"] @@ -744,36 +735,24 @@ def process_reports(reports_): if "gmail_api" in config.sections(): gmail_api_config = config["gmail_api"] - opts.gmail_api_credentials_file = \ - gmail_api_config.get("credentials_file") - opts.gmail_api_token_file = \ - gmail_api_config.get("token_file", ".token") - opts.gmail_api_include_spam_trash = \ - gmail_api_config.getboolean("include_spam_trash", False) - opts.gmail_api_scopes = \ - gmail_api_config.get("scopes", - 
default_gmail_api_scope) - opts.gmail_api_scopes = \ - _str_to_list(opts.gmail_api_scopes) + opts.gmail_api_credentials_file = gmail_api_config.get("credentials_file") + opts.gmail_api_token_file = gmail_api_config.get("token_file", ".token") + opts.gmail_api_include_spam_trash = gmail_api_config.getboolean( + "include_spam_trash", False + ) + opts.gmail_api_scopes = gmail_api_config.get("scopes", default_gmail_api_scope) + opts.gmail_api_scopes = _str_to_list(opts.gmail_api_scopes) if "oauth2_port" in gmail_api_config: - opts.gmail_api_oauth2_port = \ - gmail_api_config.get("oauth2_port", 8080) + opts.gmail_api_oauth2_port = gmail_api_config.get("oauth2_port", 8080) if "log_analytics" in config.sections(): log_analytics_config = config["log_analytics"] - opts.la_client_id = \ - log_analytics_config.get("client_id") - opts.la_client_secret = \ - log_analytics_config.get("client_secret") - opts.la_tenant_id = \ - log_analytics_config.get("tenant_id") - opts.la_dce = \ - log_analytics_config.get("dce") - opts.la_dcr_immutable_id = \ - log_analytics_config.get("dcr_immutable_id") - opts.la_dcr_aggregate_stream = \ - log_analytics_config.get("dcr_aggregate_stream") - opts.la_dcr_forensic_stream = \ - log_analytics_config.get("dcr_forensic_stream") + opts.la_client_id = log_analytics_config.get("client_id") + opts.la_client_secret = log_analytics_config.get("client_secret") + opts.la_tenant_id = log_analytics_config.get("tenant_id") + opts.la_dce = log_analytics_config.get("dce") + opts.la_dcr_immutable_id = log_analytics_config.get("dcr_immutable_id") + opts.la_dcr_aggregate_stream = log_analytics_config.get("dcr_aggregate_stream") + opts.la_dcr_forensic_stream = log_analytics_config.get("dcr_forensic_stream") logger.setLevel(logging.ERROR) @@ -789,17 +768,19 @@ def process_reports(reports_): log_file.close() fh = logging.FileHandler(opts.log_file) formatter = logging.Formatter( - '%(asctime)s - ' - '%(levelname)s - [%(filename)s:%(lineno)d] - %(message)s') + "%(asctime)s - " "%(levelname)s - [%(filename)s:%(lineno)d] - %(message)s" + ) fh.setFormatter(formatter) logger.addHandler(fh) except Exception as error: logger.warning("Unable to write to log file: {}".format(error)) - if opts.imap_host is None \ - and opts.graph_client_id is None \ - and opts.gmail_api_credentials_file is None \ - and len(opts.file_path) == 0: + if ( + opts.imap_host is None + and opts.graph_client_id is None + and opts.gmail_api_credentials_file is None + and len(opts.file_path) == 0 + ): logger.error("You must supply input files or a mailbox connection") exit(1) @@ -812,19 +793,20 @@ def process_reports(reports_): es_forensic_index = "dmarc_forensic" if opts.elasticsearch_index_suffix: suffix = opts.elasticsearch_index_suffix - es_aggregate_index = "{0}_{1}".format( - es_aggregate_index, suffix) - es_forensic_index = "{0}_{1}".format( - es_forensic_index, suffix) - elastic.set_hosts(opts.elasticsearch_hosts, - opts.elasticsearch_ssl, - opts.elasticsearch_ssl_cert_path, - opts.elasticsearch_username, - opts.elasticsearch_password, - opts.elasticsearch_apiKey, - timeout=opts.elasticsearch_timeout) - elastic.migrate_indexes(aggregate_indexes=[es_aggregate_index], - forensic_indexes=[es_forensic_index]) + es_aggregate_index = "{0}_{1}".format(es_aggregate_index, suffix) + es_forensic_index = "{0}_{1}".format(es_forensic_index, suffix) + elastic.set_hosts( + opts.elasticsearch_hosts, + opts.elasticsearch_ssl, + opts.elasticsearch_ssl_cert_path, + opts.elasticsearch_username, + opts.elasticsearch_password, + 
opts.elasticsearch_apiKey, + timeout=opts.elasticsearch_timeout, + ) + elastic.migrate_indexes( + aggregate_indexes=[es_aggregate_index], forensic_indexes=[es_forensic_index] + ) except elastic.ElasticsearchError: logger.exception("Elasticsearch Error") exit(1) @@ -853,16 +835,13 @@ def process_reports(reports_): if opts.hec: if opts.hec_token is None or opts.hec_index is None: - logger.error("HEC token and HEC index are required when " - "using HEC URL") + logger.error("HEC token and HEC index are required when " "using HEC URL") exit(1) verify = True if opts.hec_skip_certificate_verification: verify = False - hec_client = splunk.HECClient(opts.hec, opts.hec_token, - opts.hec_index, - verify=verify) + hec_client = splunk.HECClient(opts.hec, opts.hec_token, opts.hec_index, verify=verify) if opts.kafka_hosts: try: @@ -876,7 +855,7 @@ def process_reports(reports_): opts.kafka_hosts, username=opts.kafka_username, password=opts.kafka_password, - ssl_context=ssl_context + ssl_context=ssl_context, ) except Exception as error_: logger.error("Kafka Error: {0}".format(error_.__str__())) @@ -899,17 +878,21 @@ def process_reports(reports_): for mbox_path in mbox_paths: file_paths.remove(mbox_path) - counter = Value('i', 0) + counter = Value("i", 0) pool = Pool(opts.n_procs, initializer=init, initargs=(counter,)) - results = pool.starmap_async(cli_parse, - zip(file_paths, - repeat(opts.strip_attachment_payloads), - repeat(opts.nameservers), - repeat(opts.dns_timeout), - repeat(opts.ip_db_path), - repeat(opts.offline), - repeat(opts.n_procs >= 1)), - opts.chunk_size) + results = pool.starmap_async( + cli_parse, + zip( + file_paths, + repeat(opts.strip_attachment_payloads), + repeat(opts.nameservers), + repeat(opts.dns_timeout), + repeat(opts.ip_db_path), + repeat(opts.offline), + repeat(opts.n_procs >= 1), + ), + opts.chunk_size, + ) if sys.stdout.isatty(): pbar = tqdm(total=len(file_paths)) while not results.ready(): @@ -925,8 +908,7 @@ def process_reports(reports_): for result in results: if type(result[0]) is InvalidDMARCReport: - logger.error("Failed to parse {0} - {1}".format(result[1], - result[0])) + logger.error("Failed to parse {0} - {1}".format(result[1], result[0])) else: if result[0]["report_type"] == "aggregate": aggregate_reports.append(result[0]["report"]) @@ -935,13 +917,15 @@ def process_reports(reports_): for mbox_path in mbox_paths: strip = opts.strip_attachment_payloads - reports = get_dmarc_reports_from_mbox(mbox_path, - nameservers=opts.nameservers, - dns_timeout=opts.dns_timeout, - strip_attachment_payloads=strip, - ip_db_path=opts.ip_db_path, - offline=opts.offline, - parallel=False) + reports = get_dmarc_reports_from_mbox( + mbox_path, + nameservers=opts.nameservers, + dns_timeout=opts.dns_timeout, + strip_attachment_payloads=strip, + ip_db_path=opts.ip_db_path, + offline=opts.offline, + parallel=False, + ) aggregate_reports += reports["aggregate_reports"] forensic_reports += reports["forensic_reports"] @@ -949,8 +933,7 @@ def process_reports(reports_): if opts.imap_host: try: if opts.imap_user is None or opts.imap_password is None: - logger.error("IMAP user and password must be specified if" - "host is specified") + logger.error("IMAP user and password must be specified if" "host is specified") ssl = True verify = True @@ -987,7 +970,7 @@ def process_reports(reports_): username=opts.graph_user, password=opts.graph_password, token_file=opts.graph_token_file, - allow_unencrypted_storage=opts.graph_allow_unencrypted_storage + 
allow_unencrypted_storage=opts.graph_allow_unencrypted_storage, ) except Exception: @@ -996,11 +979,13 @@ def process_reports(reports_): if opts.gmail_api_credentials_file: if opts.mailbox_delete: - if 'https://mail.google.com/' not in opts.gmail_api_scopes: - logger.error("Message deletion requires scope" - " 'https://mail.google.com/'. " - "Add the scope and remove token file " - "to acquire proper access.") + if "https://mail.google.com/" not in opts.gmail_api_scopes: + logger.error( + "Message deletion requires scope" + " 'https://mail.google.com/'. " + "Add the scope and remove token file " + "to acquire proper access." + ) opts.mailbox_delete = False try: @@ -1010,7 +995,7 @@ def process_reports(reports_): scopes=opts.gmail_api_scopes, include_spam_trash=opts.gmail_api_include_spam_trash, reports_folder=opts.mailbox_reports_folder, - oauth2_port=opts.gmail_api_oauth2_port + oauth2_port=opts.gmail_api_oauth2_port, ) except Exception: @@ -1039,8 +1024,9 @@ def process_reports(reports_): logger.exception("Mailbox Error") exit(1) - results = OrderedDict([("aggregate_reports", aggregate_reports), - ("forensic_reports", forensic_reports)]) + results = OrderedDict( + [("aggregate_reports", aggregate_reports), ("forensic_reports", forensic_reports)] + ) process_reports(results) @@ -1049,11 +1035,17 @@ def process_reports(reports_): verify = True if opts.smtp_skip_certificate_verification: verify = False - email_results(results, opts.smtp_host, opts.smtp_from, - opts.smtp_to, port=opts.smtp_port, verify=verify, - username=opts.smtp_user, - password=opts.smtp_password, - subject=opts.smtp_subject) + email_results( + results, + opts.smtp_host, + opts.smtp_from, + opts.smtp_to, + port=opts.smtp_port, + verify=verify, + username=opts.smtp_user, + password=opts.smtp_password, + subject=opts.smtp_subject, + ) except Exception: logger.exception("Failed to email results") exit(1) @@ -1075,7 +1067,8 @@ def process_reports(reports_): strip_attachment_payloads=opts.strip_attachment_payloads, batch_size=opts.mailbox_batch_size, ip_db_path=opts.ip_db_path, - offline=opts.offline) + offline=opts.offline, + ) except FileExistsError as error: logger.error("{0}".format(error.__str__())) exit(1) diff --git a/parsedmarc/elastic.py b/parsedmarc/elastic.py index 02198c9b..14a5e817 100644 --- a/parsedmarc/elastic.py +++ b/parsedmarc/elastic.py @@ -3,8 +3,20 @@ from collections import OrderedDict from elasticsearch_dsl.search import Q -from elasticsearch_dsl import connections, Object, Document, Index, Nested, \ - InnerDoc, Integer, Text, Boolean, Ip, Date, Search +from elasticsearch_dsl import ( + connections, + Object, + Document, + Index, + Nested, + InnerDoc, + Integer, + Text, + Boolean, + Ip, + Date, + Search, +) from elasticsearch.helpers import reindex from parsedmarc.log import logger @@ -74,24 +86,19 @@ class Index: spf_results = Nested(_SPFResult) def add_policy_override(self, type_, comment): - self.policy_overrides.append(_PolicyOverride(type=type_, - comment=comment)) + self.policy_overrides.append(_PolicyOverride(type=type_, comment=comment)) def add_dkim_result(self, domain, selector, result): - self.dkim_results.append(_DKIMResult(domain=domain, - selector=selector, - result=result)) + self.dkim_results.append(_DKIMResult(domain=domain, selector=selector, result=result)) def add_spf_result(self, domain, scope, result): - self.spf_results.append(_SPFResult(domain=domain, - scope=scope, - result=result)) + self.spf_results.append(_SPFResult(domain=domain, scope=scope, result=result)) - def save(self, 
** kwargs): + def save(self, **kwargs): self.passed_dmarc = False self.passed_dmarc = self.spf_aligned or self.dkim_aligned - return super().save(** kwargs) + return super().save(**kwargs) class _EmailAddressDoc(InnerDoc): @@ -121,24 +128,21 @@ class _ForensicSampleDoc(InnerDoc): attachments = Nested(_EmailAttachmentDoc) def add_to(self, display_name, address): - self.to.append(_EmailAddressDoc(display_name=display_name, - address=address)) + self.to.append(_EmailAddressDoc(display_name=display_name, address=address)) def add_reply_to(self, display_name, address): - self.reply_to.append(_EmailAddressDoc(display_name=display_name, - address=address)) + self.reply_to.append(_EmailAddressDoc(display_name=display_name, address=address)) def add_cc(self, display_name, address): - self.cc.append(_EmailAddressDoc(display_name=display_name, - address=address)) + self.cc.append(_EmailAddressDoc(display_name=display_name, address=address)) def add_bcc(self, display_name, address): - self.bcc.append(_EmailAddressDoc(display_name=display_name, - address=address)) + self.bcc.append(_EmailAddressDoc(display_name=display_name, address=address)) def add_attachment(self, filename, content_type, sha256): - self.attachments.append(_EmailAttachmentDoc(filename=filename, - content_type=content_type, sha256=sha256)) + self.attachments.append( + _EmailAttachmentDoc(filename=filename, content_type=content_type, sha256=sha256) + ) class _ForensicReportDoc(Document): @@ -168,8 +172,15 @@ class AlreadySaved(ValueError): """Raised when a report to be saved matches an existing report""" -def set_hosts(hosts, use_ssl=False, ssl_cert_path=None, - username=None, password=None, apiKey=None, timeout=60.0): +def set_hosts( + hosts, + use_ssl=False, + ssl_cert_path=None, + username=None, + password=None, + apiKey=None, + timeout=60.0, +): """ Sets the Elasticsearch hosts to use @@ -184,21 +195,18 @@ def set_hosts(hosts, use_ssl=False, ssl_cert_path=None, """ if not isinstance(hosts, list): hosts = [hosts] - conn_params = { - "hosts": hosts, - "timeout": timeout - } + conn_params = {"hosts": hosts, "timeout": timeout} if use_ssl: - conn_params['use_ssl'] = True + conn_params["use_ssl"] = True if ssl_cert_path: - conn_params['verify_certs'] = True - conn_params['ca_certs'] = ssl_cert_path + conn_params["verify_certs"] = True + conn_params["ca_certs"] = ssl_cert_path else: - conn_params['verify_certs'] = False + conn_params["verify_certs"] = False if username: - conn_params['http_auth'] = (username+":"+password) + conn_params["http_auth"] = username + ":" + password if apiKey: - conn_params['api_key'] = apiKey + conn_params["api_key"] = apiKey connections.create_connection(**conn_params) @@ -217,14 +225,12 @@ def create_indexes(names, settings=None): if not index.exists(): logger.debug("Creating Elasticsearch index: {0}".format(name)) if settings is None: - index.settings(number_of_shards=1, - number_of_replicas=0) + index.settings(number_of_shards=1, number_of_replicas=0) else: index.settings(**settings) index.create() except Exception as e: - raise ElasticsearchError( - "Elasticsearch error: {0}".format(e.__str__())) + raise ElasticsearchError("Elasticsearch error: {0}".format(e.__str__())) def migrate_indexes(aggregate_indexes=None, forensic_indexes=None): @@ -256,32 +262,30 @@ def migrate_indexes(aggregate_indexes=None, forensic_indexes=None): fo_type = fo_mapping["type"] if fo_type == "long": new_index_name = "{0}-v{1}".format(aggregate_index_name, version) - body = {"properties": {"published_policy.fo": { - "type": 
"text", - "fields": { - "keyword": { - "type": "keyword", - "ignore_above": 256 + body = { + "properties": { + "published_policy.fo": { + "type": "text", + "fields": {"keyword": {"type": "keyword", "ignore_above": 256}}, } } } - } - } Index(new_index_name).create() Index(new_index_name).put_mapping(doc_type=doc, body=body) - reindex(connections.get_connection(), aggregate_index_name, - new_index_name) + reindex(connections.get_connection(), aggregate_index_name, new_index_name) Index(aggregate_index_name).delete() for forensic_index in forensic_indexes: pass -def save_aggregate_report_to_elasticsearch(aggregate_report, - index_suffix=None, - monthly_indexes=False, - number_of_shards=1, - number_of_replicas=0): +def save_aggregate_report_to_elasticsearch( + aggregate_report, + index_suffix=None, + monthly_indexes=False, + number_of_shards=1, + number_of_replicas=0, +): """ Saves a parsed DMARC aggregate report to ElasticSearch @@ -301,10 +305,8 @@ def save_aggregate_report_to_elasticsearch(aggregate_report, org_name = metadata["org_name"] report_id = metadata["report_id"] domain = aggregate_report["policy_published"]["domain"] - begin_date = human_timestamp_to_datetime(metadata["begin_date"], - to_utc=True) - end_date = human_timestamp_to_datetime(metadata["end_date"], - to_utc=True) + begin_date = human_timestamp_to_datetime(metadata["begin_date"], to_utc=True) + end_date = human_timestamp_to_datetime(metadata["end_date"], to_utc=True) begin_date_human = begin_date.strftime("%Y-%m-%d %H:%M:%SZ") end_date_human = end_date.strftime("%Y-%m-%d %H:%M:%SZ") if monthly_indexes: @@ -313,8 +315,7 @@ def save_aggregate_report_to_elasticsearch(aggregate_report, index_date = begin_date.strftime("%Y-%m-%d") aggregate_report["begin_date"] = begin_date aggregate_report["end_date"] = end_date - date_range = [aggregate_report["begin_date"], - aggregate_report["end_date"]] + date_range = [aggregate_report["begin_date"], aggregate_report["end_date"]] org_name_query = Q(dict(match_phrase=dict(org_name=org_name))) report_id_query = Q(dict(match_phrase=dict(report_id=report_id))) @@ -333,18 +334,20 @@ def save_aggregate_report_to_elasticsearch(aggregate_report, try: existing = search.execute() except Exception as error_: - raise ElasticsearchError("Elasticsearch's search for existing report \ - error: {}".format(error_.__str__())) + raise ElasticsearchError( + "Elasticsearch's search for existing report \ + error: {}".format( + error_.__str__() + ) + ) if len(existing) > 0: - raise AlreadySaved("An aggregate report ID {0} from {1} about {2} " - "with a date range of {3} UTC to {4} UTC already " - "exists in " - "Elasticsearch".format(report_id, - org_name, - domain, - begin_date_human, - end_date_human)) + raise AlreadySaved( + "An aggregate report ID {0} from {1} about {2} " + "with a date range of {3} UTC to {4} UTC already " + "exists in " + "Elasticsearch".format(report_id, org_name, domain, begin_date_human, end_date_human) + ) published_policy = _PublishedPolicy( domain=aggregate_report["policy_published"]["domain"], adkim=aggregate_report["policy_published"]["adkim"], @@ -352,7 +355,7 @@ def save_aggregate_report_to_elasticsearch(aggregate_report, p=aggregate_report["policy_published"]["p"], sp=aggregate_report["policy_published"]["sp"], pct=aggregate_report["policy_published"]["pct"], - fo=aggregate_report["policy_published"]["fo"] + fo=aggregate_report["policy_published"]["fo"], ) for record in aggregate_report["records"]: @@ -373,66 +376,69 @@ def 
save_aggregate_report_to_elasticsearch(aggregate_report, source_base_domain=record["source"]["base_domain"], message_count=record["count"], disposition=record["policy_evaluated"]["disposition"], - dkim_aligned=record["policy_evaluated"]["dkim"] is not None and - record["policy_evaluated"]["dkim"].lower() == "pass", - spf_aligned=record["policy_evaluated"]["spf"] is not None and - record["policy_evaluated"]["spf"].lower() == "pass", + dkim_aligned=record["policy_evaluated"]["dkim"] is not None + and record["policy_evaluated"]["dkim"].lower() == "pass", + spf_aligned=record["policy_evaluated"]["spf"] is not None + and record["policy_evaluated"]["spf"].lower() == "pass", header_from=record["identifiers"]["header_from"], envelope_from=record["identifiers"]["envelope_from"], - envelope_to=record["identifiers"]["envelope_to"] + envelope_to=record["identifiers"]["envelope_to"], ) for override in record["policy_evaluated"]["policy_override_reasons"]: - agg_doc.add_policy_override(type_=override["type"], - comment=override["comment"]) + agg_doc.add_policy_override(type_=override["type"], comment=override["comment"]) for dkim_result in record["auth_results"]["dkim"]: - agg_doc.add_dkim_result(domain=dkim_result["domain"], - selector=dkim_result["selector"], - result=dkim_result["result"]) + agg_doc.add_dkim_result( + domain=dkim_result["domain"], + selector=dkim_result["selector"], + result=dkim_result["result"], + ) for spf_result in record["auth_results"]["spf"]: - agg_doc.add_spf_result(domain=spf_result["domain"], - scope=spf_result["scope"], - result=spf_result["result"]) + agg_doc.add_spf_result( + domain=spf_result["domain"], scope=spf_result["scope"], result=spf_result["result"] + ) index = "dmarc_aggregate" if index_suffix: index = "{0}_{1}".format(index, index_suffix) index = "{0}-{1}".format(index, index_date) - index_settings = dict(number_of_shards=number_of_shards, - number_of_replicas=number_of_replicas) + index_settings = dict( + number_of_shards=number_of_shards, number_of_replicas=number_of_replicas + ) create_indexes([index], index_settings) agg_doc.meta.index = index try: agg_doc.save() except Exception as e: - raise ElasticsearchError( - "Elasticsearch error: {0}".format(e.__str__())) + raise ElasticsearchError("Elasticsearch error: {0}".format(e.__str__())) -def save_forensic_report_to_elasticsearch(forensic_report, - index_suffix=None, - monthly_indexes=False, - number_of_shards=1, - number_of_replicas=0): +def save_forensic_report_to_elasticsearch( + forensic_report, + index_suffix=None, + monthly_indexes=False, + number_of_shards=1, + number_of_replicas=0, +): """ - Saves a parsed DMARC forensic report to ElasticSearch - - Args: - forensic_report (OrderedDict): A parsed forensic report - index_suffix (str): The suffix of the name of the index to save to - monthly_indexes (bool): Use monthly indexes instead of daily - indexes - number_of_shards (int): The number of shards to use in the index - number_of_replicas (int): The number of replicas to use in the - index - - Raises: - AlreadySaved + Saves a parsed DMARC forensic report to ElasticSearch + + Args: + forensic_report (OrderedDict): A parsed forensic report + index_suffix (str): The suffix of the name of the index to save to + monthly_indexes (bool): Use monthly indexes instead of daily + indexes + number_of_shards (int): The number of shards to use in the index + number_of_replicas (int): The number of replicas to use in the + index - """ + Raises: + AlreadySaved + + """ logger.info("Saving forensic report to 
Elasticsearch") forensic_report = forensic_report.copy() sample_date = None @@ -474,14 +480,12 @@ def save_forensic_report_to_elasticsearch(forensic_report, existing = search.execute() if len(existing) > 0: - raise AlreadySaved("A forensic sample to {0} from {1} " - "with a subject of {2} and arrival date of {3} " - "already exists in " - "Elasticsearch".format(to_, - from_, - subject, - arrival_date_human - )) + raise AlreadySaved( + "A forensic sample to {0} from {1} " + "with a subject of {2} and arrival date of {3} " + "already exists in " + "Elasticsearch".format(to_, from_, subject, arrival_date_human) + ) parsed_sample = forensic_report["parsed_sample"] sample = _ForensicSampleDoc( @@ -491,25 +495,23 @@ def save_forensic_report_to_elasticsearch(forensic_report, date=sample_date, subject=forensic_report["parsed_sample"]["subject"], filename_safe_subject=parsed_sample["filename_safe_subject"], - body=forensic_report["parsed_sample"]["body"] + body=forensic_report["parsed_sample"]["body"], ) for address in forensic_report["parsed_sample"]["to"]: - sample.add_to(display_name=address["display_name"], - address=address["address"]) + sample.add_to(display_name=address["display_name"], address=address["address"]) for address in forensic_report["parsed_sample"]["reply_to"]: - sample.add_reply_to(display_name=address["display_name"], - address=address["address"]) + sample.add_reply_to(display_name=address["display_name"], address=address["address"]) for address in forensic_report["parsed_sample"]["cc"]: - sample.add_cc(display_name=address["display_name"], - address=address["address"]) + sample.add_cc(display_name=address["display_name"], address=address["address"]) for address in forensic_report["parsed_sample"]["bcc"]: - sample.add_bcc(display_name=address["display_name"], - address=address["address"]) + sample.add_bcc(display_name=address["display_name"], address=address["address"]) for attachment in forensic_report["parsed_sample"]["attachments"]: - sample.add_attachment(filename=attachment["filename"], - content_type=attachment["mail_content_type"], - sha256=attachment["sha256"]) + sample.add_attachment( + filename=attachment["filename"], + content_type=attachment["mail_content_type"], + sha256=attachment["sha256"], + ) try: forensic_doc = _ForensicReportDoc( feedback_type=forensic_report["feedback_type"], @@ -525,12 +527,11 @@ def save_forensic_report_to_elasticsearch(forensic_report, source_country=forensic_report["source"]["country"], source_reverse_dns=forensic_report["source"]["reverse_dns"], source_base_domain=forensic_report["source"]["base_domain"], - authentication_mechanisms=forensic_report[ - "authentication_mechanisms"], + authentication_mechanisms=forensic_report["authentication_mechanisms"], auth_failure=forensic_report["auth_failure"], dkim_domain=forensic_report["dkim_domain"], original_rcpt_to=forensic_report["original_rcpt_to"], - sample=sample + sample=sample, ) index = "dmarc_forensic" @@ -541,15 +542,16 @@ def save_forensic_report_to_elasticsearch(forensic_report, else: index_date = arrival_date.strftime("%Y-%m-%d") index = "{0}-{1}".format(index, index_date) - index_settings = dict(number_of_shards=number_of_shards, - number_of_replicas=number_of_replicas) + index_settings = dict( + number_of_shards=number_of_shards, number_of_replicas=number_of_replicas + ) create_indexes([index], index_settings) forensic_doc.meta.index = index try: forensic_doc.save() except Exception as e: - raise ElasticsearchError( - "Elasticsearch error: {0}".format(e.__str__())) + raise 
ElasticsearchError("Elasticsearch error: {0}".format(e.__str__())) except KeyError as e: raise InvalidForensicReport( - "Forensic report missing required field: {0}".format(e.__str__())) + "Forensic report missing required field: {0}".format(e.__str__()) + ) diff --git a/parsedmarc/kafkaclient.py b/parsedmarc/kafkaclient.py index 02bf833a..53feeb51 100644 --- a/parsedmarc/kafkaclient.py +++ b/parsedmarc/kafkaclient.py @@ -17,8 +17,7 @@ class KafkaError(RuntimeError): class KafkaClient(object): - def __init__(self, kafka_hosts, ssl=False, username=None, - password=None, ssl_context=None): + def __init__(self, kafka_hosts, ssl=False, username=None, password=None, ssl_context=None): """ Initializes the Kafka client Args: @@ -37,10 +36,11 @@ def __init__(self, kafka_hosts, ssl=False, username=None, ``$ConnectionString``, and the password is the Azure Event Hub connection string. """ - config = dict(value_serializer=lambda v: json.dumps(v).encode( - 'utf-8'), - bootstrap_servers=kafka_hosts, - client_id="parsedmarc-{0}".format(__version__)) + config = dict( + value_serializer=lambda v: json.dumps(v).encode("utf-8"), + bootstrap_servers=kafka_hosts, + client_id="parsedmarc-{0}".format(__version__), + ) if ssl or username or password: config["security_protocol"] = "SSL" config["ssl_context"] = ssl_context or create_default_context() @@ -55,14 +55,14 @@ def __init__(self, kafka_hosts, ssl=False, username=None, @staticmethod def strip_metadata(report): """ - Duplicates org_name, org_email and report_id into JSON root - and removes report_metadata key to bring it more inline - with Elastic output. + Duplicates org_name, org_email and report_id into JSON root + and removes report_metadata key to bring it more inline + with Elastic output. """ - report['org_name'] = report['report_metadata']['org_name'] - report['org_email'] = report['report_metadata']['org_email'] - report['report_id'] = report['report_metadata']['report_id'] - report.pop('report_metadata') + report["org_name"] = report["report_metadata"]["org_name"] + report["org_email"] = report["report_metadata"]["org_email"] + report["report_id"] = report["report_metadata"]["report_id"] + report.pop("report_metadata") return report @@ -80,13 +80,11 @@ def generate_daterange(report): end_date = human_timestamp_to_datetime(metadata["end_date"]) begin_date_human = begin_date.strftime("%Y-%m-%dT%H:%M:%S") end_date_human = end_date.strftime("%Y-%m-%dT%H:%M:%S") - date_range = [begin_date_human, - end_date_human] + date_range = [begin_date_human, end_date_human] logger.debug("date_range is {}".format(date_range)) return date_range - def save_aggregate_reports_to_kafka(self, aggregate_reports, - aggregate_topic): + def save_aggregate_reports_to_kafka(self, aggregate_reports, aggregate_topic): """ Saves aggregate DMARC reports to Kafka @@ -96,38 +94,34 @@ def save_aggregate_reports_to_kafka(self, aggregate_reports, aggregate_topic (str): The name of the Kafka topic """ - if (isinstance(aggregate_reports, dict) or - isinstance(aggregate_reports, OrderedDict)): + if isinstance(aggregate_reports, dict) or isinstance(aggregate_reports, OrderedDict): aggregate_reports = [aggregate_reports] if len(aggregate_reports) < 1: return for report in aggregate_reports: - report['date_range'] = self.generate_daterange(report) + report["date_range"] = self.generate_daterange(report) report = self.strip_metadata(report) - for slice in report['records']: - slice['date_range'] = report['date_range'] - slice['org_name'] = report['org_name'] - slice['org_email'] = 
report['org_email'] - slice['policy_published'] = report['policy_published'] - slice['report_id'] = report['report_id'] + for slice in report["records"]: + slice["date_range"] = report["date_range"] + slice["org_name"] = report["org_name"] + slice["org_email"] = report["org_email"] + slice["policy_published"] = report["policy_published"] + slice["report_id"] = report["report_id"] logger.debug("Sending slice.") try: logger.debug("Saving aggregate report to Kafka") self.producer.send(aggregate_topic, slice) except UnknownTopicOrPartitionError: - raise KafkaError( - "Kafka error: Unknown topic or partition on broker") + raise KafkaError("Kafka error: Unknown topic or partition on broker") except Exception as e: - raise KafkaError( - "Kafka error: {0}".format(e.__str__())) + raise KafkaError("Kafka error: {0}".format(e.__str__())) try: self.producer.flush() except Exception as e: - raise KafkaError( - "Kafka error: {0}".format(e.__str__())) + raise KafkaError("Kafka error: {0}".format(e.__str__())) def save_forensic_reports_to_kafka(self, forensic_reports, forensic_topic): """ @@ -151,13 +145,10 @@ def save_forensic_reports_to_kafka(self, forensic_reports, forensic_topic): logger.debug("Saving forensic reports to Kafka") self.producer.send(forensic_topic, forensic_reports) except UnknownTopicOrPartitionError: - raise KafkaError( - "Kafka error: Unknown topic or partition on broker") + raise KafkaError("Kafka error: Unknown topic or partition on broker") except Exception as e: - raise KafkaError( - "Kafka error: {0}".format(e.__str__())) + raise KafkaError("Kafka error: {0}".format(e.__str__())) try: self.producer.flush() except Exception as e: - raise KafkaError( - "Kafka error: {0}".format(e.__str__())) + raise KafkaError("Kafka error: {0}".format(e.__str__())) diff --git a/parsedmarc/loganalytics.py b/parsedmarc/loganalytics.py index 78e018e6..15686921 100644 --- a/parsedmarc/loganalytics.py +++ b/parsedmarc/loganalytics.py @@ -1,163 +1,90 @@ -# -*- coding: utf-8 -*- -from parsedmarc.log import logger +from typing import List, Dict, Optional + from azure.core.exceptions import HttpResponseError from azure.identity import ClientSecretCredential from azure.monitor.ingestion import LogsIngestionClient +from parsedmarc.log import logger + class LogAnalyticsException(Exception): - """Raised when an Elasticsearch error occurs""" + """Errors originating from LogsIngestionClient""" -class LogAnalyticsConfig(): - """ - The LogAnalyticsConfig class is used to define the configuration - for the Log Analytics Client. +class LogAnalyticsClient(object): + """Azure Log Analytics Client + + Pushes the DMARC reports to Log Analytics via Data Collection Rules. - Properties: - client_id (str): - The client ID of the service principle. - client_secret (str): - The client secret of the service principle. - tenant_id (str): - The tenant ID where - the service principle resides. - dce (str): - The Data Collection Endpoint (DCE) - used by the Data Collection Rule (DCR). - dcr_immutable_id (str): - The immutable ID of - the Data Collection Rule (DCR). - dcr_aggregate_stream (str): - The Stream name where - the Aggregate DMARC reports - need to be pushed. - dcr_forensic_stream (str): - The Stream name where - the Forensic DMARC reports - need to be pushed. 
+    References:
+        - https://learn.microsoft.com/en-us/azure/azure-monitor/logs/logs-ingestion-api-overview
     """
+
     def __init__(
-            self,
-            client_id: str,
-            client_secret: str,
-            tenant_id: str,
-            dce: str,
-            dcr_immutable_id: str,
-            dcr_aggregate_stream: str,
-            dcr_forensic_stream: str):
+        self,
+        client_id: str,
+        client_secret: str,
+        tenant_id: str,
+        dce: str,
+        dcr_immutable_id: str,
+        dcr_aggregate_stream: Optional[str] = None,
+        dcr_forensic_stream: Optional[str] = None,
+    ):
+        """
+        Args:
+            client_id: The client ID of the service principal.
+            client_secret: The client secret of the service principal.
+            tenant_id: The tenant ID where the service principal resides.
+            dce: The Data Collection Endpoint (DCE) used by the Data Collection Rule (DCR).
+            dcr_immutable_id: The immutable ID of the Data Collection Rule (DCR).
+            dcr_aggregate_stream: The Stream name where the Aggregate DMARC reports need to be pushed.
+            dcr_forensic_stream: The Stream name where the Forensic DMARC reports need to be pushed.
+        """
         self.client_id = client_id
-        self.client_secret = client_secret
+        self._client_secret = client_secret
         self.tenant_id = tenant_id
         self.dce = dce
         self.dcr_immutable_id = dcr_immutable_id
         self.dcr_aggregate_stream = dcr_aggregate_stream
         self.dcr_forensic_stream = dcr_forensic_stream
 
-
-class LogAnalyticsClient(object):
-    """
-    The LogAnalyticsClient is used to push
-    the generated DMARC reports to Log Analytics
-    via Data Collection Rules.
-    """
-    def __init__(
-            self,
-            client_id: str,
-            client_secret: str,
-            tenant_id: str,
-            dce: str,
-            dcr_immutable_id: str,
-            dcr_aggregate_stream: str,
-            dcr_forensic_stream: str):
-        self.conf = LogAnalyticsConfig(
-            client_id=client_id,
-            client_secret=client_secret,
-            tenant_id=tenant_id,
-            dce=dce,
-            dcr_immutable_id=dcr_immutable_id,
-            dcr_aggregate_stream=dcr_aggregate_stream,
-            dcr_forensic_stream=dcr_forensic_stream
+        self._credential = ClientSecretCredential(
+            tenant_id=tenant_id, client_id=client_id, client_secret=client_secret
         )
-        if (
-            not self.conf.client_id or
-            not self.conf.client_secret or
-            not self.conf.tenant_id or
-            not self.conf.dce or
-            not self.conf.dcr_immutable_id):
-            raise LogAnalyticsException(
-                "Invalid configuration. " +
-                "One or more required settings are missing.")
+        self.logs_client = LogsIngestionClient(dce, credential=self._credential)
+        return
 
-    def publish_json(
-            self,
-            results,
-            logs_client: LogsIngestionClient,
-            dcr_stream: str):
-        """
-        Background function to publish given
-        DMARC reprot to specific Data Collection Rule.
+    def _publish_json(self, reports: List[Dict], dcr_stream: str) -> None:
+        """Publish DMARC reports to the given Data Collection Rule.
 
         Args:
-            results (list):
-                The results generated by parsedmarc.
-            logs_client (LogsIngestionClient):
-                The client used to send the DMARC reports.
-            dcr_stream (str):
-                The stream name where the DMARC reports needs to be pushed.
+            reports: The DMARC reports generated by parsedmarc,
+                as a list of dicts (aggregate or forensic).
+            dcr_stream: The stream name where the DMARC reports need to be pushed.
""" try: - logs_client.upload(self.conf.dcr_immutable_id, dcr_stream, results) # type: ignore[attr-defined] + self.logs_client.upload(self.dcr_immutable_id, dcr_stream, reports) # type: ignore[attr-defined] except HttpResponseError as e: - raise LogAnalyticsException( - "Upload failed: {error}" - .format(error=e)) + raise LogAnalyticsException(f"Upload failed: {e!r}") + return def publish_results( - self, - results, - save_aggregate: bool, - save_forensic: bool): - """ - Function to publish DMARC reports to Log Analytics - via Data Collection Rules (DCR). - Look below for docs: - https://learn.microsoft.com/en-us/azure/azure-monitor/logs/logs-ingestion-api-overview + self, results: Dict[str, List[Dict]], save_aggregate: bool, save_forensic: bool + ) -> None: + """Publish DMARC reports to Log Analytics via Data Collection Rules (DCR). Args: - results (list): - The DMARC reports (Aggregate & Forensic) - save_aggregate (bool): - Whether Aggregate reports can be saved into Log Analytics - save_forensic (bool): - Whether Forensic reports can be saved into Log Analytics + results: The DMARC reports (Aggregate & Forensic) + save_aggregate: Whether Aggregate reports can be saved into Log Analytics + save_forensic: Whether Forensic reports can be saved into Log Analytics """ - conf = self.conf - credential = ClientSecretCredential( - tenant_id=conf.tenant_id, - client_id=conf.client_id, - client_secret=conf.client_secret - ) - logs_client = LogsIngestionClient(conf.dce, credential=credential) - if ( - results['aggregate_reports'] and - conf.dcr_aggregate_stream and - len(results['aggregate_reports']) > 0 and - save_aggregate): + if results["aggregate_reports"] and self.dcr_aggregate_stream and save_aggregate: logger.info("Publishing aggregate reports.") - self.publish_json( - results['aggregate_reports'], - logs_client, - conf.dcr_aggregate_stream) + self._publish_json(results["aggregate_reports"], self.dcr_aggregate_stream) logger.info("Successfully pushed aggregate reports.") - if ( - results['forensic_reports'] and - conf.dcr_forensic_stream and - len(results['forensic_reports']) > 0 and - save_forensic): + + if results["forensic_reports"] and self.dcr_forensic_stream and save_forensic: logger.info("Publishing forensic reports.") - self.publish_json( - results['forensic_reports'], - logs_client, - conf.dcr_forensic_stream) + self._publish_json(results["forensic_reports"], self.dcr_forensic_stream) logger.info("Successfully pushed forensic reports.") + return diff --git a/parsedmarc/mail/__init__.py b/parsedmarc/mail/__init__.py index df3c4f2c..5d40b3ad 100644 --- a/parsedmarc/mail/__init__.py +++ b/parsedmarc/mail/__init__.py @@ -3,7 +3,4 @@ from parsedmarc.mail.gmail import GmailConnection from parsedmarc.mail.imap import IMAPConnection -__all__ = ["MailboxConnection", - "MSGraphConnection", - "GmailConnection", - "IMAPConnection"] +__all__ = ["MailboxConnection", "MSGraphConnection", "GmailConnection", "IMAPConnection"] diff --git a/parsedmarc/mail/gmail.py b/parsedmarc/mail/gmail.py index 40f61af4..09bdbd1c 100644 --- a/parsedmarc/mail/gmail.py +++ b/parsedmarc/mail/gmail.py @@ -29,17 +29,20 @@ def _get_creds(token_file, credentials_file, scopes, oauth2_port): if creds and creds.expired and creds.refresh_token: creds.refresh(Request()) else: - flow = InstalledAppFlow.from_client_secrets_file( - credentials_file, scopes) - creds = flow.run_local_server(open_browser=False, - oauth2_port=oauth2_port) + flow = InstalledAppFlow.from_client_secrets_file(credentials_file, scopes) + creds = 
flow.run_local_server(open_browser=False, oauth2_port=oauth2_port) # Save the credentials for the next run - with Path(token_file).open('w') as token: + with Path(token_file).open("w") as token: token.write(creds.to_json()) return creds class GmailConnection(MailboxConnection): + """MailboxConnection for Google accounts using the Google API. + + This will support both Gmail and Google Workspace accounts. + """ + def __init__( self, token_file: str, @@ -50,64 +53,59 @@ def __init__( oauth2_port: int, ): creds = _get_creds(token_file, credentials_file, scopes, oauth2_port) - self.service = build('gmail', 'v1', credentials=creds) + self.service = build("gmail", "v1", credentials=creds) self.include_spam_trash = include_spam_trash self.reports_label_id = self._find_label_id_for_label(reports_folder) def create_folder(self, folder_name: str) -> None: # Gmail doesn't support the name Archive - if folder_name == 'Archive': + if folder_name == "Archive": return logger.debug(f"Creating label {folder_name}") - request_body: "Label" = {'name': folder_name, 'messageListVisibility': 'show'} + request_body: "Label" = {"name": folder_name, "messageListVisibility": "show"} try: - self.service.users().labels()\ - .create(userId='me', body=request_body).execute() + self.service.users().labels().create(userId="me", body=request_body).execute() except HttpError as e: if e.status_code == 409: - logger.debug(f'Folder {folder_name} already exists, ' - f'skipping creation') + logger.debug(f"Folder {folder_name} already exists, " f"skipping creation") else: raise e return def fetch_messages(self, reports_folder: str, **kwargs) -> List[str]: reports_label_id = self._find_label_id_for_label(reports_folder) - results = self.service.users().messages()\ - .list(userId='me', - includeSpamTrash=self.include_spam_trash, - labelIds=[reports_label_id] - )\ + results = ( + self.service.users() + .messages() + .list( + userId="me", includeSpamTrash=self.include_spam_trash, labelIds=[reports_label_id] + ) .execute() - messages = results.get('messages', []) - return [message['id'] for message in messages] + ) + messages = results.get("messages", []) + return [message["id"] for message in messages] def fetch_message(self, message_id: str) -> bytes: - msg = self.service.users().messages()\ - .get(userId='me', - id=message_id, - format="raw" - )\ - .execute() - return urlsafe_b64decode(msg['raw']) + msg = ( + self.service.users().messages().get(userId="me", id=message_id, format="raw").execute() + ) + return urlsafe_b64decode(msg["raw"]) def delete_message(self, message_id: str) -> None: - self.service.users().messages().delete(userId='me', id=message_id) + self.service.users().messages().delete(userId="me", id=message_id) return def move_message(self, message_id: str, folder_name: str): label_id = self._find_label_id_for_label(folder_name) logger.debug(f"Moving message UID {message_id} to {folder_name}") request_body: "ModifyMessageRequest" = { - 'addLabelIds': [label_id], - 'removeLabelIds': [self.reports_label_id] + "addLabelIds": [label_id], + "removeLabelIds": [self.reports_label_id], } - self.service.users().messages()\ - .modify(userId='me', - id=message_id, - body=request_body)\ - .execute() + self.service.users().messages().modify( + userId="me", id=message_id, body=request_body + ).execute() return def keepalive(self) -> None: @@ -115,7 +113,7 @@ def keepalive(self) -> None: return def watch(self, check_callback, check_timeout) -> None: - """ Checks the mailbox for new messages every n seconds""" + """Checks the mailbox 
for new messages every n seconds""" while True: sleep(check_timeout) check_callback(self) @@ -123,9 +121,9 @@ def watch(self, check_callback, check_timeout) -> None: @lru_cache(maxsize=10) def _find_label_id_for_label(self, label_name: str) -> str: - results = self.service.users().labels().list(userId='me').execute() - labels = results.get('labels', []) + results = self.service.users().labels().list(userId="me").execute() + labels = results.get("labels", []) for label in labels: - if label_name == label['id'] or label_name == label['name']: - return label['id'] + if label_name == label["id"] or label_name == label["name"]: + return label["id"] raise ValueError(f"Label {label_name} not found") diff --git a/parsedmarc/mail/graph.py b/parsedmarc/mail/graph.py index 4387138a..07a383e5 100644 --- a/parsedmarc/mail/graph.py +++ b/parsedmarc/mail/graph.py @@ -4,9 +4,13 @@ from time import sleep from typing import Dict, List, Optional, Union, Any -from azure.identity import UsernamePasswordCredential, \ - DeviceCodeCredential, ClientSecretCredential, \ - TokenCachePersistenceOptions, AuthenticationRecord +from azure.identity import ( + UsernamePasswordCredential, + DeviceCodeCredential, + ClientSecretCredential, + TokenCachePersistenceOptions, + AuthenticationRecord, +) from msgraph.core import GraphClient from parsedmarc.log import logger @@ -18,20 +22,20 @@ class AuthMethod(Enum): UsernamePassword = 2 ClientSecret = 3 + Credential = Union[DeviceCodeCredential, UsernamePasswordCredential, ClientSecretCredential] + def _get_cache_args(token_path: Path, allow_unencrypted_storage: bool): cache_args: Dict[str, Any] = { - 'cache_persistence_options': - TokenCachePersistenceOptions( - name='parsedmarc', - allow_unencrypted_storage=allow_unencrypted_storage, - ) + "cache_persistence_options": TokenCachePersistenceOptions( + name="parsedmarc", + allow_unencrypted_storage=allow_unencrypted_storage, + ) } auth_record = _load_token(token_path) if auth_record: - cache_args['authentication_record'] = \ - AuthenticationRecord.deserialize(auth_record) + cache_args["authentication_record"] = AuthenticationRecord.deserialize(auth_record) return cache_args @@ -44,7 +48,7 @@ def _load_token(token_path: Path) -> Optional[str]: def _cache_auth_record(record: AuthenticationRecord, token_path: Path): token = record.serialize() - with token_path.open('w') as token_file: + with token_path.open("w") as token_file: token_file.write(token) @@ -52,50 +56,54 @@ def _generate_credential(auth_method: str, token_path: Path, **kwargs) -> Creden credential: Credential if auth_method == AuthMethod.DeviceCode.name: credential = DeviceCodeCredential( - client_id=kwargs['client_id'], - client_secret=kwargs['client_secret'], + client_id=kwargs["client_id"], + client_secret=kwargs["client_secret"], disable_automatic_authentication=True, - tenant_id=kwargs['tenant_id'], + tenant_id=kwargs["tenant_id"], **_get_cache_args( - token_path, - allow_unencrypted_storage=kwargs['allow_unencrypted_storage']) + token_path, allow_unencrypted_storage=kwargs["allow_unencrypted_storage"] + ), ) return credential if auth_method == AuthMethod.UsernamePassword.name: credential = UsernamePasswordCredential( - client_id=kwargs['client_id'], - client_credential=kwargs['client_secret'], + client_id=kwargs["client_id"], + client_credential=kwargs["client_secret"], disable_automatic_authentication=True, - username=kwargs['username'], - password=kwargs['password'], + username=kwargs["username"], + password=kwargs["password"], **_get_cache_args( - token_path, - 
allow_unencrypted_storage=kwargs['allow_unencrypted_storage']) + token_path, allow_unencrypted_storage=kwargs["allow_unencrypted_storage"] + ), ) return credential if auth_method == AuthMethod.ClientSecret.name: credential = ClientSecretCredential( - client_id=kwargs['client_id'], - tenant_id=kwargs['tenant_id'], - client_secret=kwargs['client_secret'] + client_id=kwargs["client_id"], + tenant_id=kwargs["tenant_id"], + client_secret=kwargs["client_secret"], ) return credential - raise RuntimeError(f'Auth method {auth_method} not found') + raise RuntimeError(f"Auth method {auth_method} not found") class MSGraphConnection(MailboxConnection): - def __init__(self, - auth_method: str, - mailbox: str, - client_id: str, - client_secret: str, - username: str, - password: str, - tenant_id: str, - token_file: str, - allow_unencrypted_storage: bool): + """MailboxConnection to a Microsoft account via the Micorsoft Graph API""" + + def __init__( + self, + auth_method: str, + mailbox: str, + client_id: str, + client_secret: str, + username: str, + password: str, + tenant_id: str, + token_file: str, + allow_unencrypted_storage: bool, + ): token_path = Path(token_file) credential = _generate_credential( auth_method, @@ -105,156 +113,139 @@ def __init__(self, password=password, tenant_id=tenant_id, token_path=token_path, - allow_unencrypted_storage=allow_unencrypted_storage) - client_params: Dict[str, Any] = { - 'credential': credential - } + allow_unencrypted_storage=allow_unencrypted_storage, + ) + client_params: Dict[str, Any] = {"credential": credential} if isinstance(credential, (DeviceCodeCredential, UsernamePasswordCredential)): - scopes = ['Mail.ReadWrite'] + scopes = ["Mail.ReadWrite"] # Detect if mailbox is shared if mailbox and username != mailbox: - scopes = ['Mail.ReadWrite.Shared'] + scopes = ["Mail.ReadWrite.Shared"] auth_record = credential.authenticate(scopes=scopes) _cache_auth_record(auth_record, token_path) - client_params['scopes'] = scopes + client_params["scopes"] = scopes self._client = GraphClient(**client_params) self.mailbox_name = mailbox def create_folder(self, folder_name: str) -> None: - sub_url = '' - path_parts = folder_name.split('/') + sub_url = "" + path_parts = folder_name.split("/") if len(path_parts) > 1: # Folder is a subFolder parent_folder_id = None for folder in path_parts[:-1]: - parent_folder_id = self._find_folder_id_with_parent( - folder, parent_folder_id) - sub_url = f'/{parent_folder_id}/childFolders' + parent_folder_id = self._find_folder_id_with_parent(folder, parent_folder_id) + sub_url = f"/{parent_folder_id}/childFolders" folder_name = path_parts[-1] - request_body = { - 'displayName': folder_name - } - request_url = f'/users/{self.mailbox_name}/mailFolders{sub_url}' + request_body = {"displayName": folder_name} + request_url = f"/users/{self.mailbox_name}/mailFolders{sub_url}" resp = self._client.post(request_url, json=request_body) if resp.status_code == 409: - logger.debug(f'Folder {folder_name} already exists, ' - f'skipping creation') + logger.debug(f"Folder {folder_name} already exists, " f"skipping creation") elif resp.status_code == 201: - logger.debug(f'Created folder {folder_name}') + logger.debug(f"Created folder {folder_name}") else: - logger.warning(f'Unknown response ' - f'{resp.status_code} {resp.json()}') + logger.warning(f"Unknown response " f"{resp.status_code} {resp.json()}") def fetch_messages(self, reports_folder: str, **kwargs) -> List[str]: - """ Returns a list of message UIDs in the specified folder """ + """Returns a list of 
message UIDs in the specified folder""" folder_id = self._find_folder_id_from_folder_path(reports_folder) - url = f'/users/{self.mailbox_name}/mailFolders/' \ - f'{folder_id}/messages' - batch_size = kwargs.get('batch_size') + url = f"/users/{self.mailbox_name}/mailFolders/" f"{folder_id}/messages" + batch_size = kwargs.get("batch_size") if not batch_size: batch_size = 0 emails = self._get_all_messages(url, batch_size) - return [email['id'] for email in emails] + return [email["id"] for email in emails] def _get_all_messages(self, url: str, batch_size: int): messages: list - params: Dict[str, Any] = { - '$select': 'id' - } + params: Dict[str, Any] = {"$select": "id"} if batch_size and batch_size > 0: - params['$top'] = batch_size + params["$top"] = batch_size else: - params['$top'] = 100 + params["$top"] = 100 result = self._client.get(url, params=params) if result.status_code != 200: - raise RuntimeError(f'Failed to fetch messages {result.text}') - messages = result.json()['value'] + raise RuntimeError(f"Failed to fetch messages {result.text}") + messages = result.json()["value"] # Loop if next page is present and not obtained message limit. - while '@odata.nextLink' in result.json() and ( - batch_size == 0 or - batch_size - len(messages) > 0): - result = self._client.get(result.json()['@odata.nextLink']) + while "@odata.nextLink" in result.json() and ( + batch_size == 0 or batch_size - len(messages) > 0 + ): + result = self._client.get(result.json()["@odata.nextLink"]) if result.status_code != 200: - raise RuntimeError(f'Failed to fetch messages {result.text}') - messages.extend(result.json()['value']) + raise RuntimeError(f"Failed to fetch messages {result.text}") + messages.extend(result.json()["value"]) return messages def mark_message_read(self, message_id: str): """Marks a message as read""" - url = f'/users/{self.mailbox_name}/messages/{message_id}' + url = f"/users/{self.mailbox_name}/messages/{message_id}" resp = self._client.patch(url, json={"isRead": "true"}) if resp.status_code != 200: - raise RuntimeWarning(f"Failed to mark message read" - f"{resp.status_code}: {resp.json()}") + raise RuntimeWarning( + f"Failed to mark message read" f"{resp.status_code}: {resp.json()}" + ) def fetch_message(self, message_id: str) -> str: - url = f'/users/{self.mailbox_name}/messages/{message_id}/$value' + url = f"/users/{self.mailbox_name}/messages/{message_id}/$value" result = self._client.get(url) if result.status_code != 200: - raise RuntimeWarning(f"Failed to fetch message" - f"{result.status_code}: {result.json()}") + raise RuntimeWarning( + f"Failed to fetch message" f"{result.status_code}: {result.json()}" + ) self.mark_message_read(message_id) return result.text def delete_message(self, message_id: str): - url = f'/users/{self.mailbox_name}/messages/{message_id}' + url = f"/users/{self.mailbox_name}/messages/{message_id}" resp = self._client.delete(url) if resp.status_code != 204: - raise RuntimeWarning(f"Failed to delete message " - f"{resp.status_code}: {resp.json()}") + raise RuntimeWarning(f"Failed to delete message " f"{resp.status_code}: {resp.json()}") def move_message(self, message_id: str, folder_name: str): folder_id = self._find_folder_id_from_folder_path(folder_name) - request_body = { - 'destinationId': folder_id - } - url = f'/users/{self.mailbox_name}/messages/{message_id}/move' + request_body = {"destinationId": folder_id} + url = f"/users/{self.mailbox_name}/messages/{message_id}/move" resp = self._client.post(url, json=request_body) if resp.status_code != 201: - 
raise RuntimeWarning(f"Failed to move message " - f"{resp.status_code}: {resp.json()}") + raise RuntimeWarning(f"Failed to move message " f"{resp.status_code}: {resp.json()}") def keepalive(self): # Not needed pass def watch(self, check_callback, check_timeout): - """ Checks the mailbox for new messages every n seconds""" + """Checks the mailbox for new messages every n seconds""" while True: sleep(check_timeout) check_callback(self) @lru_cache(maxsize=10) def _find_folder_id_from_folder_path(self, folder_name: str) -> str: - path_parts = folder_name.split('/') + path_parts = folder_name.split("/") parent_folder_id = None if len(path_parts) > 1: for folder in path_parts[:-1]: - folder_id = self._find_folder_id_with_parent( - folder, parent_folder_id) + folder_id = self._find_folder_id_with_parent(folder, parent_folder_id) parent_folder_id = folder_id - return self._find_folder_id_with_parent( - path_parts[-1], parent_folder_id) + return self._find_folder_id_with_parent(path_parts[-1], parent_folder_id) else: return self._find_folder_id_with_parent(folder_name, None) - def _find_folder_id_with_parent(self, - folder_name: str, - parent_folder_id: Optional[str]): - sub_url = '' + def _find_folder_id_with_parent(self, folder_name: str, parent_folder_id: Optional[str]): + sub_url = "" if parent_folder_id is not None: - sub_url = f'/{parent_folder_id}/childFolders' - url = f'/users/{self.mailbox_name}/mailFolders{sub_url}' + sub_url = f"/{parent_folder_id}/childFolders" + url = f"/users/{self.mailbox_name}/mailFolders{sub_url}" filter = f"?$filter=displayName eq '{folder_name}'" folders_resp = self._client.get(url + filter) if folders_resp.status_code != 200: - raise RuntimeWarning(f"Failed to list folders." - f"{folders_resp.json()}") - folders: list = folders_resp.json()['value'] - matched_folders = [folder for folder in folders - if folder['displayName'] == folder_name] + raise RuntimeWarning(f"Failed to list folders." 
f"{folders_resp.json()}") + folders: list = folders_resp.json()["value"] + matched_folders = [folder for folder in folders if folder["displayName"] == folder_name] if len(matched_folders) == 0: raise RuntimeError(f"folder {folder_name} not found") selected_folder = matched_folders[0] - return selected_folder['id'] + return selected_folder["id"] diff --git a/parsedmarc/mail/imap.py b/parsedmarc/mail/imap.py index 150185e1..81901a03 100644 --- a/parsedmarc/mail/imap.py +++ b/parsedmarc/mail/imap.py @@ -9,22 +9,32 @@ class IMAPConnection(MailboxConnection): - def __init__(self, - host=None, - user=None, - password=None, - port=None, - ssl=True, - verify=True, - timeout=30, - max_retries=4): + """MailboxConnection for connecting to a mailbox via IMAP.""" + + def __init__( + self, + host=None, + user=None, + password=None, + port=None, + ssl=True, + verify=True, + timeout=30, + max_retries=4, + ): self._username = user self._password = password self._verify = verify - self._client = IMAPClient(host, user, password, port=port, - ssl=ssl, verify=verify, - timeout=timeout, - max_retries=max_retries) + self._client = IMAPClient( + host, + user, + password, + port=port, + ssl=ssl, + verify=verify, + timeout=timeout, + max_retries=max_retries, + ) def create_folder(self, folder_name: str): self._client.create_folder(folder_name) @@ -46,10 +56,7 @@ def keepalive(self): self._client.noop() def watch(self, check_callback, check_timeout): - """ - Use an IDLE IMAP connection to parse incoming emails, - and pass the results to a callback function - """ + """Use an IDLE IMAP connection to parse incoming emails, and pass the results to a callback function""" # IDLE callback sends IMAPClient object, # send back the imap connection object instead @@ -59,18 +66,19 @@ def idle_callback_wrapper(client: IMAPClient): while True: try: - IMAPClient(host=self._client.host, - username=self._username, - password=self._password, - port=self._client.port, - ssl=self._client.ssl, - verify=self._verify, - idle_callback=idle_callback_wrapper, - idle_timeout=check_timeout) + IMAPClient( + host=self._client.host, + username=self._username, + password=self._password, + port=self._client.port, + ssl=self._client.ssl, + verify=self._verify, + idle_callback=idle_callback_wrapper, + idle_timeout=check_timeout, + ) except (timeout, IMAPClientError): logger.warning("IMAP connection timeout. Reconnecting...") sleep(check_timeout) except Exception as e: - logger.warning("IMAP connection error. {0}. " - "Reconnecting...".format(e)) + logger.warning("IMAP connection error. {0}. 
" "Reconnecting...".format(e)) sleep(check_timeout) diff --git a/parsedmarc/mail/mailbox_connection.py b/parsedmarc/mail/mailbox_connection.py index e11fa379..4f5d1b19 100644 --- a/parsedmarc/mail/mailbox_connection.py +++ b/parsedmarc/mail/mailbox_connection.py @@ -3,9 +3,8 @@ class MailboxConnection(ABC): - """ - Interface for a mailbox connection - """ + """Interface for a mailbox connection""" + def create_folder(self, folder_name: str) -> None: raise NotImplementedError diff --git a/parsedmarc/s3.py b/parsedmarc/s3.py index e8269abc..0b0fd2b5 100644 --- a/parsedmarc/s3.py +++ b/parsedmarc/s3.py @@ -10,8 +10,9 @@ class S3Client(object): """A client for a Amazon S3""" - def __init__(self, bucket_name, bucket_path, region_name, endpoint_url, - access_key_id, secret_access_key): + def __init__( + self, bucket_name, bucket_path, region_name, endpoint_url, access_key_id, secret_access_key + ): """ Initializes the S3Client Args: @@ -34,7 +35,7 @@ def __init__(self, bucket_name, bucket_path, region_name, endpoint_url, # https://github.com/boto/boto3/blob/1.24.7/boto3/session.py#L312 self.s3 = boto3.resource( - 's3', + "s3", region_name=region_name, endpoint_url=endpoint_url, aws_access_key_id=access_key_id, @@ -43,15 +44,13 @@ def __init__(self, bucket_name, bucket_path, region_name, endpoint_url, self.bucket = self.s3.Bucket(self.bucket_name) def save_aggregate_report_to_s3(self, report): - self.save_report_to_s3(report, 'aggregate') + self.save_report_to_s3(report, "aggregate") def save_forensic_report_to_s3(self, report): - self.save_report_to_s3(report, 'forensic') + self.save_report_to_s3(report, "forensic") def save_report_to_s3(self, report, report_type): - report_date = human_timestamp_to_datetime( - report["report_metadata"]["begin_date"] - ) + report_date = human_timestamp_to_datetime(report["report_metadata"]["begin_date"]) report_id = report["report_metadata"]["report_id"] path_template = "{0}/{1}/year={2}/month={3:02d}/day={4:02d}/{5}.json" object_path = path_template.format( @@ -60,19 +59,12 @@ def save_report_to_s3(self, report, report_type): report_date.year, report_date.month, report_date.day, - report_id + report_id, + ) + logger.debug( + "Saving {0} report to s3://{1}/{2}".format(report_type, self.bucket_name, object_path) ) - logger.debug("Saving {0} report to s3://{1}/{2}".format( - report_type, - self.bucket_name, - object_path)) object_metadata = { - k: v - for k, v in report["report_metadata"].items() - if k in self.metadata_keys + k: v for k, v in report["report_metadata"].items() if k in self.metadata_keys } - self.bucket.put_object( - Body=json.dumps(report), - Key=object_path, - Metadata=object_metadata - ) + self.bucket.put_object(Body=json.dumps(report), Key=object_path, Metadata=object_metadata) diff --git a/parsedmarc/splunk.py b/parsedmarc/splunk.py index 36285238..9b316e23 100644 --- a/parsedmarc/splunk.py +++ b/parsedmarc/splunk.py @@ -22,8 +22,7 @@ class HECClient(object): # http://docs.splunk.com/Documentation/Splunk/latest/Data/AboutHEC # http://docs.splunk.com/Documentation/Splunk/latest/RESTREF/RESTinput#services.2Fcollector - def __init__(self, url, access_token, index, - source="parsedmarc", verify=True, timeout=60): + def __init__(self, url, access_token, index, source="parsedmarc", verify=True, timeout=60): """ Initializes the HECClient @@ -37,8 +36,7 @@ def __init__(self, url, access_token, index, data before giving up """ url = urlparse(url) - self.url = "{0}://{1}/services/collector/event/1.0".format(url.scheme, - url.netloc) + self.url = 
"{0}://{1}/services/collector/event/1.0".format(url.scheme, url.netloc) self.access_token = access_token.lstrip("Splunk ") self.index = index self.host = socket.getfqdn() @@ -46,12 +44,11 @@ def __init__(self, url, access_token, index, self.session = requests.Session() self.timeout = timeout self.session.verify = verify - self._common_data = dict(host=self.host, source=self.source, - index=self.index) + self._common_data = dict(host=self.host, source=self.source, index=self.index) self.session.headers = { "User-Agent": "parsedmarc/{0}".format(__version__), - "Authorization": "Splunk {0}".format(self.access_token) + "Authorization": "Splunk {0}".format(self.access_token), } def save_aggregate_reports_to_splunk(self, aggregate_reports): @@ -78,34 +75,24 @@ def save_aggregate_reports_to_splunk(self, aggregate_reports): for metadata in report["report_metadata"]: new_report[metadata] = report["report_metadata"][metadata] new_report["published_policy"] = report["policy_published"] - new_report["source_ip_address"] = record["source"][ - "ip_address"] + new_report["source_ip_address"] = record["source"]["ip_address"] new_report["source_country"] = record["source"]["country"] - new_report["source_reverse_dns"] = record["source"][ - "reverse_dns"] - new_report["source_base_domain"] = record["source"][ - "base_domain"] + new_report["source_reverse_dns"] = record["source"]["reverse_dns"] + new_report["source_base_domain"] = record["source"]["base_domain"] new_report["message_count"] = record["count"] - new_report["disposition"] = record["policy_evaluated"][ - "disposition" - ] + new_report["disposition"] = record["policy_evaluated"]["disposition"] new_report["spf_aligned"] = record["alignment"]["spf"] new_report["dkim_aligned"] = record["alignment"]["dkim"] new_report["passed_dmarc"] = record["alignment"]["dmarc"] - new_report["header_from"] = record["identifiers"][ - "header_from"] - new_report["envelope_from"] = record["identifiers"][ - "envelope_from"] + new_report["header_from"] = record["identifiers"]["header_from"] + new_report["envelope_from"] = record["identifiers"]["envelope_from"] if "dkim" in record["auth_results"]: - new_report["dkim_results"] = record["auth_results"][ - "dkim"] + new_report["dkim_results"] = record["auth_results"]["dkim"] if "spf" in record["auth_results"]: - new_report["spf_results"] = record["auth_results"][ - "spf"] + new_report["spf_results"] = record["auth_results"]["spf"] data["sourcetype"] = "dmarc:aggregate" - timestamp = human_timestamp_to_timestamp( - new_report["begin_date"]) + timestamp = human_timestamp_to_timestamp(new_report["begin_date"]) data["time"] = timestamp data["event"] = new_report.copy() json_str += "{0}\n".format(json.dumps(data)) @@ -113,8 +100,7 @@ def save_aggregate_reports_to_splunk(self, aggregate_reports): if not self.session.verify: logger.debug("Skipping certificate verification for Splunk HEC") try: - response = self.session.post(self.url, data=json_str, - timeout=self.timeout) + response = self.session.post(self.url, data=json_str, timeout=self.timeout) response = response.json() except Exception as e: raise SplunkError(e.__str__()) @@ -140,8 +126,7 @@ def save_forensic_reports_to_splunk(self, forensic_reports): for report in forensic_reports: data = self._common_data.copy() data["sourcetype"] = "dmarc:forensic" - timestamp = human_timestamp_to_timestamp( - report["arrival_date_utc"]) + timestamp = human_timestamp_to_timestamp(report["arrival_date_utc"]) data["time"] = timestamp data["event"] = report.copy() json_str += 
"{0}\n".format(json.dumps(data)) @@ -149,8 +134,7 @@ def save_forensic_reports_to_splunk(self, forensic_reports): if not self.session.verify: logger.debug("Skipping certificate verification for Splunk HEC") try: - response = self.session.post(self.url, data=json_str, - timeout=self.timeout) + response = self.session.post(self.url, data=json_str, timeout=self.timeout) response = response.json() except Exception as e: raise SplunkError(e.__str__()) diff --git a/parsedmarc/syslog.py b/parsedmarc/syslog.py index 0aa7e1c5..d01785ef 100644 --- a/parsedmarc/syslog.py +++ b/parsedmarc/syslog.py @@ -4,8 +4,7 @@ import logging.handlers import json -from parsedmarc import parsed_aggregate_reports_to_csv_rows, \ - parsed_forensic_reports_to_csv_rows +from parsedmarc import parsed_aggregate_reports_to_csv_rows, parsed_forensic_reports_to_csv_rows class SyslogClient(object): @@ -20,10 +19,9 @@ def __init__(self, server_name, server_port): """ self.server_name = server_name self.server_port = server_port - self.logger = logging.getLogger('parsedmarc_syslog') + self.logger = logging.getLogger("parsedmarc_syslog") self.logger.setLevel(logging.INFO) - log_handler = logging.handlers.SysLogHandler(address=(server_name, - server_port)) + log_handler = logging.handlers.SysLogHandler(address=(server_name, server_port)) self.logger.addHandler(log_handler) def save_aggregate_report_to_syslog(self, aggregate_reports): diff --git a/parsedmarc/utils.py b/parsedmarc/utils.py index 08b09924..fa7f07a1 100644 --- a/parsedmarc/utils.py +++ b/parsedmarc/utils.py @@ -17,6 +17,7 @@ import mailbox import re from typing import List, Dict, Any, Optional, Union + try: import importlib.resources as pkg_resources except ImportError: @@ -36,7 +37,7 @@ import parsedmarc.resources.dbip -parenthesis_regex = re.compile(r'\s*\(.*\)\s*') +parenthesis_regex = re.compile(r"\s*\(.*\)\s*") null_file = open(os.devnull, "w") mailparser_logger = logging.getLogger("mailparser") @@ -62,8 +63,7 @@ class DownloadError(RuntimeError): def decode_base64(data: str) -> bytes: - """ - Decodes a base64 string, with padding being optional + """Decodes a base64 string, with padding being optional Args: data: A base64 encoded string @@ -75,13 +75,12 @@ def decode_base64(data: str) -> bytes: data_bytes = bytes(data, encoding="ascii") missing_padding = len(data_bytes) % 4 if missing_padding != 0: - data_bytes += b'=' * (4 - missing_padding) + data_bytes += b"=" * (4 - missing_padding) return base64.b64decode(data_bytes) def get_base_domain(domain: str) -> str: - """ - Gets the base domain name for the given domain + """Get the base domain name for the given domain .. 
note:: Results are based on a list of public domain suffixes at @@ -105,8 +104,7 @@ def query_dns( nameservers: Optional[List[str]] = None, timeout: float = 2.0, ) -> List[str]: - """ - Queries DNS + """Make a DNS query Args: domain: The domain or subdomain to query about @@ -130,24 +128,32 @@ def query_dns( resolver = dns.resolver.Resolver() timeout = float(timeout) if nameservers is None: - nameservers = ["1.1.1.1", "1.0.0.1", - "2606:4700:4700::1111", "2606:4700:4700::1001", - ] + nameservers = [ + "1.1.1.1", + "1.0.0.1", + "2606:4700:4700::1111", + "2606:4700:4700::1001", + ] resolver.nameservers = nameservers resolver.timeout = timeout resolver.lifetime = timeout if record_type == "TXT": - resource_records = list(map( - lambda r: r.strings, - resolver.resolve(domain, record_type, lifetime=timeout))) + resource_records = list( + map(lambda r: r.strings, resolver.resolve(domain, record_type, lifetime=timeout)) + ) _resource_record = [ resource_record[0][:0].join(resource_record) - for resource_record in resource_records if resource_record] + for resource_record in resource_records + if resource_record + ] records = [r.decode() for r in _resource_record] else: - records = list(map( - lambda r: r.to_text().replace('"', '').rstrip("."), - resolver.resolve(domain, record_type, lifetime=timeout))) + records = list( + map( + lambda r: r.to_text().replace('"', "").rstrip("."), + resolver.resolve(domain, record_type, lifetime=timeout), + ) + ) if cache: cache[cache_key] = records @@ -160,14 +166,12 @@ def get_reverse_dns( nameservers: Optional[List[str]] = None, timeout: float = 2.0, ) -> Optional[str]: - """ - Resolves an IP address to a hostname using a reverse DNS query + """Resolve an IP address to a hostname using a reverse DNS query Args: ip_address: The IP address to resolve cache: Cache storage - nameservers: A list of one or more nameservers to use - (Cloudflare's public DNS resolvers by default) + nameservers: A list of one or more nameservers to use (Cloudflare's public DNS resolvers by default) timeout: Sets the DNS query timeout in seconds Returns: @@ -176,9 +180,9 @@ def get_reverse_dns( hostname: Optional[str] = None try: address = str(dns.reversename.from_address(ip_address)) - hostname = query_dns(address, "PTR", cache=cache, - nameservers=nameservers, - timeout=timeout)[0] + hostname = query_dns(address, "PTR", cache=cache, nameservers=nameservers, timeout=timeout)[ + 0 + ] except dns.exception.DNSException: pass @@ -187,8 +191,7 @@ def get_reverse_dns( def timestamp_to_datetime(timestamp: int) -> datetime: - """ - Converts a UNIX/DMARC timestamp to a Python ``datetime`` object + """Converts a UNIX/DMARC timestamp to a Python ``datetime`` object Args: timestamp: The timestamp @@ -200,8 +203,7 @@ def timestamp_to_datetime(timestamp: int) -> datetime: def timestamp_to_human(timestamp: int) -> str: - """ - Converts a UNIX/DMARC timestamp to a human-readable string + """Converts a UNIX/DMARC timestamp to a human-readable string Args: timestamp: The timestamp @@ -213,8 +215,7 @@ def timestamp_to_human(timestamp: int) -> str: def human_timestamp_to_datetime(human_timestamp: str, to_utc: bool = False) -> datetime: - """ - Converts a human-readable timestamp into a Python ``datetime`` object + """Converts a human-readable timestamp into a Python ``datetime`` object Args: human_timestamp: A timestamp string @@ -232,8 +233,7 @@ def human_timestamp_to_datetime(human_timestamp: str, to_utc: bool = False) -> d def human_timestamp_to_timestamp(human_timestamp: str) -> float: - """ - 
Converts a human-readable timestamp into a UNIX timestamp + """Converts a human-readable timestamp into a UNIX timestamp Args: human_timestamp: A timestamp in `YYYY-MM-DD HH:MM:SS`` format @@ -246,9 +246,7 @@ def human_timestamp_to_timestamp(human_timestamp: str) -> float: def get_ip_address_country(ip_address: str, db_path: Optional[str] = None) -> Optional[str]: - """ - Returns the ISO code for the country associated - with the given IPv4 or IPv6 address + """Get the ISO code for the country associated with the given IPv4 or IPv6 address Args: ip_address: The IP address to query for @@ -264,8 +262,7 @@ def get_ip_address_country(ip_address: str, db_path: Optional[str] = None) -> Op "/var/lib/GeoIP/GeoLite2-Country.mmdb", "/var/local/lib/GeoIP/GeoLite2-Country.mmdb", "/usr/local/var/GeoIP/GeoLite2-Country.mmdb", - "%SystemDrive%\\ProgramData\\MaxMind\\GeoIPUpdate\\GeoIP\\" - "GeoLite2-Country.mmdb", + "%SystemDrive%\\ProgramData\\MaxMind\\GeoIPUpdate\\GeoIP\\" "GeoLite2-Country.mmdb", "C:\\GeoIP\\GeoLite2-Country.mmdb", "dbip-country-lite.mmdb", "dbip-country.mmdb", @@ -274,9 +271,11 @@ def get_ip_address_country(ip_address: str, db_path: Optional[str] = None) -> Op if db_path is not None: if os.path.isfile(db_path) is False: db_path = None - logger.warning(f"No file exists at {db_path}. Falling back to an " - "included copy of the IPDB IP to Country " - "Lite database.") + logger.warning( + f"No file exists at {db_path}. Falling back to an " + "included copy of the IPDB IP to Country " + "Lite database." + ) if db_path is None: for system_path in db_paths: @@ -285,12 +284,10 @@ def get_ip_address_country(ip_address: str, db_path: Optional[str] = None) -> Op break if db_path is None: - with pkg_resources.path(parsedmarc.resources.dbip, - "dbip-country-lite.mmdb") as path: + with pkg_resources.path(parsedmarc.resources.dbip, "dbip-country-lite.mmdb") as path: db_path = str(path) - db_age = datetime.now() - datetime.fromtimestamp( - os.stat(db_path).st_mtime) + db_age = datetime.now() - datetime.fromtimestamp(os.stat(db_path).st_mtime) if db_age > timedelta(days=30): logger.warning("IP database is more than a month old") @@ -313,18 +310,16 @@ def get_ip_address_info( offline: bool = False, nameservers: Optional[List[str]] = None, timeout: float = 2.0, - parallel: bool = False + parallel: bool = False, ) -> OrderedDict: - """ - Returns reverse DNS and country information for the given IP address + """Get reverse DNS and country information for the given IP address Args: ip_address: The IP address to check ip_db_path: path to a MMDB file from MaxMind or DBIP cache: Cache storage offline: Do not make online queries for geolocation or DNS - nameservers: A list of one or more nameservers to use - (Cloudflare's public DNS resolvers by default) + nameservers: A list of one or more nameservers to use (Cloudflare's public DNS resolvers by default) timeout: Sets the DNS timeout in seconds parallel: parallel processing (not used) @@ -342,9 +337,7 @@ def get_ip_address_info( if offline: reverse_dns = None else: - reverse_dns = get_reverse_dns(ip_address, - nameservers=nameservers, - timeout=timeout) + reverse_dns = get_reverse_dns(ip_address, nameservers=nameservers, timeout=timeout) country = get_ip_address_country(ip_address, db_path=ip_db_path) info["country"] = country info["reverse_dns"] = reverse_dns @@ -357,6 +350,7 @@ def get_ip_address_info( def parse_email_address(original_address: str) -> OrderedDict: + """Parse an email into parts""" if original_address[0] == "": display_name = None else: 
@@ -369,15 +363,13 @@ def parse_email_address(original_address: str) -> OrderedDict: local = address_parts[0].lower() domain = address_parts[-1].lower() - return OrderedDict([("display_name", display_name), - ("address", address), - ("local", local), - ("domain", domain)]) + return OrderedDict( + [("display_name", display_name), ("address", address), ("local", local), ("domain", domain)] + ) def get_filename_safe_string(string: str) -> str: - """ - Converts a string to a string that is safe for a filename + """Convert a string to a string that is safe for a filename Args: string: A string to make safe for a filename @@ -385,8 +377,7 @@ def get_filename_safe_string(string: str) -> str: Returns: A string safe for a filename """ - invalid_filename_chars = ['\\', '/', ':', '"', '*', '?', '|', '\n', - '\r'] + invalid_filename_chars = ["\\", "/", ":", '"', "*", "?", "|", "\n", "\r"] if string is None: string = "None" for char in invalid_filename_chars: @@ -399,8 +390,7 @@ def get_filename_safe_string(string: str) -> str: def is_mbox(path: str) -> bool: - """ - Checks if the given content is an MBOX mailbox file + """Checks if the given content is an MBOX mailbox file Args: Content to check @@ -420,8 +410,7 @@ def is_mbox(path: str) -> bool: def is_outlook_msg(content: Any) -> bool: - """ - Checks if the given content is an Outlook msg OLE/MSG file + """Checks if the given content is an Outlook msg OLE/MSG file Args: content: Content to check @@ -429,14 +418,13 @@ def is_outlook_msg(content: Any) -> bool: Returns: If the file is an Outlook MSG file """ - return isinstance(content, bytes) and content.startswith( - b"\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1") + return isinstance(content, bytes) and content.startswith(b"\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1") def convert_outlook_msg(msg_bytes: bytes) -> bytes: - """ - Uses the ``msgconvert`` Perl utility to convert an Outlook MS file to - standard RFC 822 format + """Convert an Outlook MS file to standard RFC 822 format + + Requires the ``msgconvert`` Perl utility to be installed. 
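+    (On Debian-based systems, ``msgconvert`` is typically provided by the
+    ``libemail-outlook-message-perl`` package.)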
Args: msg_bytes: the content of the .msg file @@ -452,14 +440,12 @@ def convert_outlook_msg(msg_bytes: bytes) -> bytes: with open("sample.msg", "wb") as msg_file: msg_file.write(msg_bytes) try: - subprocess.check_call(["msgconvert", "sample.msg"], - stdout=null_file, stderr=null_file) + subprocess.check_call(["msgconvert", "sample.msg"], stdout=null_file, stderr=null_file) eml_path = "sample.eml" with open(eml_path, "rb") as eml_file: rfc822 = eml_file.read() except FileNotFoundError: - raise EmailParserError( - "Failed to convert Outlook MSG: msgconvert utility not found") + raise EmailParserError("Failed to convert Outlook MSG: msgconvert utility not found") finally: os.chdir(orig_dir) shutil.rmtree(tmp_dir) @@ -468,8 +454,7 @@ def convert_outlook_msg(msg_bytes: bytes) -> bytes: def parse_email(data: Union[bytes, str], strip_attachment_payloads: bool = False) -> Dict[str, Any]: - """ - A simplified email parser + """A simplified email parser Args: data: The RFC 822 message string, or MSG binary @@ -494,8 +479,7 @@ def parse_email(data: Union[bytes, str], strip_attachment_payloads: bool = False if received["date_utc"] is None: del received["date_utc"] else: - received["date_utc"] = received["date_utc"].replace("T", - " ") + received["date_utc"] = received["date_utc"].replace("T", " ") if "from" not in parsed_email: if "From" in parsed_email["headers"]: @@ -511,33 +495,30 @@ def parse_email(data: Union[bytes, str], strip_attachment_payloads: bool = False else: parsed_email["date"] = None if "reply_to" in parsed_email: - parsed_email["reply_to"] = list(map(lambda x: parse_email_address(x), - parsed_email["reply_to"])) + parsed_email["reply_to"] = list( + map(lambda x: parse_email_address(x), parsed_email["reply_to"]) + ) else: parsed_email["reply_to"] = [] if "to" in parsed_email: - parsed_email["to"] = list(map(lambda x: parse_email_address(x), - parsed_email["to"])) + parsed_email["to"] = list(map(lambda x: parse_email_address(x), parsed_email["to"])) else: parsed_email["to"] = [] if "cc" in parsed_email: - parsed_email["cc"] = list(map(lambda x: parse_email_address(x), - parsed_email["cc"])) + parsed_email["cc"] = list(map(lambda x: parse_email_address(x), parsed_email["cc"])) else: parsed_email["cc"] = [] if "bcc" in parsed_email: - parsed_email["bcc"] = list(map(lambda x: parse_email_address(x), - parsed_email["bcc"])) + parsed_email["bcc"] = list(map(lambda x: parse_email_address(x), parsed_email["bcc"])) else: parsed_email["bcc"] = [] if "delivered_to" in parsed_email: parsed_email["delivered_to"] = list( - map(lambda x: parse_email_address(x), - parsed_email["delivered_to"]) + map(lambda x: parse_email_address(x), parsed_email["delivered_to"]) ) if "attachments" not in parsed_email: @@ -554,9 +535,7 @@ def parse_email(data: Union[bytes, str], strip_attachment_payloads: bool = False payload = str.encode(payload) attachment["sha256"] = hashlib.sha256(payload).hexdigest() except Exception as e: - logger.debug("Unable to decode attachment: {0}".format( - e.__str__() - )) + logger.debug("Unable to decode attachment: {0}".format(e.__str__())) if strip_attachment_payloads: for attachment in parsed_email["attachments"]: if "payload" in attachment: @@ -565,8 +544,7 @@ def parse_email(data: Union[bytes, str], strip_attachment_payloads: bool = False if "subject" not in parsed_email: parsed_email["subject"] = None - parsed_email["filename_safe_subject"] = get_filename_safe_string( - parsed_email["subject"]) + parsed_email["filename_safe_subject"] = 
get_filename_safe_string(parsed_email["subject"]) if "body" not in parsed_email: parsed_email["body"] = None From 211c237ccb2d2e8643f2f00cf28f0cf49a7a64d9 Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Wed, 3 Jan 2024 02:51:18 +1100 Subject: [PATCH 04/15] [typing] More typing and some refactoring --- parsedmarc/__init__.py | 141 ++++++++++++-------------------- parsedmarc/cli.py | 46 +++++------ parsedmarc/elastic.py | 168 ++++++++++++++++++-------------------- parsedmarc/kafkaclient.py | 84 ++++++++++--------- parsedmarc/mail/gmail.py | 9 ++ parsedmarc/mail/graph.py | 12 +++ parsedmarc/mail/imap.py | 30 +++++-- parsedmarc/s3.py | 38 +++++---- parsedmarc/splunk.py | 84 +++++++++++-------- parsedmarc/syslog.py | 9 +- parsedmarc/utils.py | 22 +++-- tests.py | 21 +++-- 12 files changed, 339 insertions(+), 325 deletions(-) diff --git a/parsedmarc/__init__.py b/parsedmarc/__init__.py index 48c3b74b..fd8bbec8 100644 --- a/parsedmarc/__init__.py +++ b/parsedmarc/__init__.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """A Python package for parsing DMARC reports""" import binascii @@ -36,7 +34,7 @@ __version__ = "8.6.4" -logger.debug("parsedmarc v{0}".format(__version__)) +logger.debug(f"parsedmarc v{__version__}") feedback_report_regex = re.compile(r"^([\w\-]+): (.+)$", re.MULTILINE) xml_header_regex = re.compile(r"^<\?xml .*?>", re.MULTILINE) @@ -241,7 +239,7 @@ def parse_aggregate_report_xml( try: xmltodict.parse(xml)["feedback"] except Exception as e: - errors.append("Invalid XML: {0}".format(e.__str__())) + errors.append(f"Invalid XML: {e!r}") try: tree = etree.parse( BytesIO(xml.encode("utf-8")), etree.XMLParser(recover=True, resolve_entities=False) @@ -274,11 +272,9 @@ def parse_aggregate_report_xml( if new_org_name is not None: org_name = new_org_name if not org_name: - logger.debug("Could not parse org_name from XML.\r\n{0}".format(report.__str__())) + logger.debug(f"Could not parse org_name from XML.\r\n{report}") raise KeyError( - "Organization name is missing. \ - This field is a requirement for \ - saving the report" + "Organization name is missing. 
This field is a requirement for saving the report" ) new_report_metadata["org_name"] = org_name new_report_metadata["org_email"] = report_metadata["email"] @@ -338,11 +334,12 @@ def parse_aggregate_report_xml( new_report["policy_published"] = new_policy_published if type(report["record"]) is list: - for i in range(len(report["record"])): + record_count = len(report["record"]) + for i in range(record_count): if keep_alive is not None and i > 0 and i % 20 == 0: logger.debug("Sending keepalive cmd") keep_alive() - logger.debug("Processed {0}/{1}".format(i, len(report["record"]))) + logger.debug(f"Processed {i}/{record_count}") report_record = _parse_report_record( report["record"][i], ip_db_path=ip_db_path, @@ -369,15 +366,15 @@ def parse_aggregate_report_xml( return new_report except expat.ExpatError as error: - raise InvalidAggregateReport("Invalid XML: {0}".format(error.__str__())) + raise InvalidAggregateReport(f"Invalid XML: {error!r}") except KeyError as error: - raise InvalidAggregateReport("Missing field: {0}".format(error.__str__())) + raise InvalidAggregateReport(f"Missing field: {error!r}") except AttributeError: raise InvalidAggregateReport("Report missing required section") except Exception as error: - raise InvalidAggregateReport("Unexpected error: {0}".format(error.__str__())) + raise InvalidAggregateReport(f"Unexpected error: {error!r}") def extract_xml(input_: Union[str, bytes, BinaryIO]) -> str: @@ -388,7 +385,6 @@ def extract_xml(input_: Union[str, bytes, BinaryIO]) -> str: Returns: The extracted XML - """ file_object: BinaryIO try: @@ -425,7 +421,7 @@ def extract_xml(input_: Union[str, bytes, BinaryIO]) -> str: raise InvalidAggregateReport("File objects must be opened in binary " "(rb) mode") except Exception as error: file_object.close() - raise InvalidAggregateReport("Invalid archive file: {0}".format(error.__str__())) + raise InvalidAggregateReport(f"Invalid archive file: {error!r}") return xml @@ -659,11 +655,9 @@ def parse_forensic_report( msg_date: The message's date header offline: Do not query online for geolocation or DNS ip_db_path: Path to a MMDB file from MaxMind or DBIP - nameservers (list): A list of one or more nameservers to use - (Cloudflare's public DNS resolvers by default) + nameservers (list): A list of one or more nameservers to use (Cloudflare's public DNS resolvers by default) dns_timeout: Sets the DNS timeout in seconds - strip_attachment_payloads: Remove attachment payloads from - forensic report results + strip_attachment_payloads: Remove attachment payloads from forensic report results parallel: Parallel processing Returns: @@ -761,10 +755,10 @@ def parse_forensic_report( return parsed_report except KeyError as error: - raise InvalidForensicReport("Missing value: {0}".format(error.__str__())) + raise InvalidForensicReport(f"Missing value: {error!r}") except Exception as error: - raise InvalidForensicReport("Unexpected error: {0}".format(error.__str__())) + raise InvalidForensicReport("Unexpected error: {error!r}") def parsed_forensic_reports_to_csv_rows( @@ -867,8 +861,7 @@ def parse_report_email( ip_db_path: Path to a MMDB file from MaxMind or DBIP nameservers: A list of one or more nameservers to use dns_timeout: Sets the DNS timeout in seconds - strip_attachment_payloads: Remove attachment payloads from - forensic report results + strip_attachment_payloads: Remove attachment payloads from forensic report results parallel: Parallel processing keep_alive: keep alive function @@ -894,7 +887,7 @@ def parse_report_email( feedback_report = None 
sample = None if "From" in msg_headers: - logger.info("Parsing mail from {0}".format(msg_headers["From"])) + logger.info(f"Parsing mail from {msg_headers['From']}") if "Subject" in msg_headers: subject = msg_headers["Subject"] for part in msg.walk(): @@ -927,11 +920,7 @@ def parse_report_email( for match in field_matches: field_name = match[0].lower().replace(" ", "-") fields[field_name] = match[1].strip() - feedback_report = ( - "Arrival-Date: {}\n" - "Source-IP: {}" - "".format(fields["received-date"], fields["sender-ip-address"]) - ) + feedback_report = f"Arrival-Date: {fields['received-date']}\nSource-IP: {fields['sender-ip-address']}" sample = parts[1].lstrip() sample = sample.replace("=\r\n", "") logger.debug(sample) @@ -960,14 +949,12 @@ def parse_report_email( except InvalidAggregateReport as e: error = ( - 'Message with subject "{0}" ' - "is not a valid " - "aggregate DMARC report: {1}".format(subject, e) + f"Message with subject {subject!r} is not a valid aggregate DMARC report: {e!r}" ) raise InvalidAggregateReport(error) except Exception as e: - error = "Unable to parse message with " 'subject "{0}": {1}'.format(subject, e) + error = f"Unable to parse message with subject {subject!r}: {e!r}" raise InvalidDMARCReport(error) if feedback_report and sample: @@ -983,18 +970,14 @@ def parse_report_email( parallel=parallel, ) except InvalidForensicReport as e: - error = ( - 'Message with subject "{0}" ' - "is not a valid " - "forensic DMARC report: {1}".format(subject, e) - ) + error = f"Message with subject {subject!r} is not a valid forensic DMARC report: {e!r}" raise InvalidForensicReport(error) except Exception as e: - raise InvalidForensicReport(e.__str__()) + raise InvalidForensicReport(repr(e)) return OrderedDict([("report_type", "forensic"), ("report", forensic_report)]) - error = 'Message with subject "{0}" is not a valid DMARC report'.format(subject) + error = f"Message with subject {subject!r} is not a valid DMARC report" raise InvalidDMARCReport(error) @@ -1025,7 +1008,7 @@ def parse_report_file( """ file_object: BinaryIO if isinstance(input_, str): - logger.debug("Parsing {0}".format(input_)) + logger.debug(f"Parsing {input_}") file_object = open(input_, "rb") elif isinstance(input_, bytes): file_object = BytesIO(input_) @@ -1087,7 +1070,6 @@ def get_dmarc_reports_from_mbox( Returns: Dictionary of Lists of ``aggregate_reports`` and ``forensic_reports`` - """ aggregate_reports: List[OrderedDict] = [] forensic_reports: List[OrderedDict] = [] @@ -1095,10 +1077,10 @@ def get_dmarc_reports_from_mbox( mbox = mailbox.mbox(input_) message_keys = mbox.keys() total_messages = len(message_keys) - logger.debug("Found {0} messages in {1}".format(total_messages, input_)) + logger.debug(f"Found {total_messages} messages in {input_}") for i in range(len(message_keys)): message_key = message_keys[i] - logger.info("Processing message {0} of {1}".format(i + 1, total_messages)) + logger.info(f"Processing message {i+1} of {total_messages}") msg_content = mbox.get_string(message_key) try: sa = strip_attachment_payloads @@ -1118,7 +1100,7 @@ def get_dmarc_reports_from_mbox( except InvalidDMARCReport as error: logger.warning(error.__str__()) except mailbox.NoSuchMailboxError: - raise InvalidDMARCReport("Mailbox {0} does not exist".format(input_)) + raise InvalidDMARCReport(f"Mailbox {input_} does not exist") return OrderedDict( [("aggregate_reports", aggregate_reports), ("forensic_reports", forensic_reports)] ) @@ -1169,9 +1151,9 @@ def get_dmarc_reports_from_mailbox( forensic_reports = [] 
aggregate_report_msg_uids = [] forensic_report_msg_uids = [] - aggregate_reports_folder = "{0}/Aggregate".format(archive_folder) - forensic_reports_folder = "{0}/Forensic".format(archive_folder) - invalid_reports_folder = "{0}/Invalid".format(archive_folder) + aggregate_reports_folder = f"{archive_folder}/Aggregate" + forensic_reports_folder = f"{archive_folder}/Forensic" + invalid_reports_folder = f"{archive_folder}/Invalid" if results: aggregate_reports = results["aggregate_reports"].copy() @@ -1185,18 +1167,18 @@ def get_dmarc_reports_from_mailbox( messages = connection.fetch_messages(reports_folder, batch_size=batch_size) total_messages = len(messages) - logger.debug("Found {0} messages in {1}".format(len(messages), reports_folder)) + logger.debug(f"Found {len(messages)} messages in {reports_folder}") if batch_size: message_limit = min(total_messages, batch_size) else: message_limit = total_messages - logger.debug("Processing {0} messages".format(message_limit)) + logger.debug(f"Processing {message_limit} messages") for i in range(message_limit): msg_uid = messages[i] - logger.debug("Processing message {0} of {1}: UID {2}".format(i + 1, message_limit, msg_uid)) + logger.debug(f"Processing message {i+1} of {message_limit}: UID {msg_uid}") msg_content = connection.fetch_message(msg_uid) try: sa = strip_attachment_payloads @@ -1219,12 +1201,10 @@ def get_dmarc_reports_from_mailbox( logger.warning(error.__str__()) if not test: if delete: - logger.debug("Deleting message UID {0}".format(msg_uid)) + logger.debug(f"Deleting message UID {msg_uid}") connection.delete_message(msg_uid) else: - logger.debug( - "Moving message UID {0} to {1}".format(msg_uid, invalid_reports_folder) - ) + logger.debug(f"Moving message UID {msg_uid} to {invalid_reports_folder}") connection.move_message(msg_uid, invalid_reports_folder) if not test: @@ -1234,61 +1214,42 @@ def get_dmarc_reports_from_mailbox( number_of_processed_msgs = len(processed_messages) for i in range(number_of_processed_msgs): msg_uid = processed_messages[i] - logger.debug( - "Deleting message {0} of {1}: UID {2}".format( - i + 1, number_of_processed_msgs, msg_uid - ) - ) + logger.debug(f"Deleting message {i+1} of {number_of_processed_msgs}: UID {msg_uid}") try: connection.delete_message(msg_uid) except Exception as e: - message = "Mailbox error: Error deleting message UID {0}: {1}".format( - msg_uid, repr(e) - ) - logger.error(message) + logger.error(f"Mailbox error: Error deleting message UID {msg_uid}: {e!r}") else: if len(aggregate_report_msg_uids) > 0: - log_message = "Moving aggregate report messages from" logger.debug( - "{0} {1} to {2}".format(log_message, reports_folder, aggregate_reports_folder) + f"Moving aggregate report messages from {reports_folder} to {aggregate_reports_folder}" ) number_of_agg_report_msgs = len(aggregate_report_msg_uids) for i in range(number_of_agg_report_msgs): msg_uid = aggregate_report_msg_uids[i] logger.debug( - "Moving message {0} of {1}: UID {2}".format( - i + 1, number_of_agg_report_msgs, msg_uid - ) + f"Moving message {i+1} of {number_of_agg_report_msgs}: UID {msg_uid}" ) try: connection.move_message(msg_uid, aggregate_reports_folder) except Exception as e: - message = "Mailbox error: Error moving message UID {0}: {1}".format( - msg_uid, repr(e) - ) - logger.error(message) + logger.error(f"Mailbox error: Error moving message UID {msg_uid}: {e!r}") if len(forensic_report_msg_uids) > 0: - message = "Moving forensic report messages from" logger.debug( - "{0} {1} to {2}".format(message, reports_folder, 
forensic_reports_folder) + f"Moving forensic report messages from {reports_folder} to {forensic_reports_folder}" ) number_of_forensic_msgs = len(forensic_report_msg_uids) for i in range(number_of_forensic_msgs): msg_uid = forensic_report_msg_uids[i] - message = "Moving message" logger.debug( - "{0} {1} of {2}: UID {3}".format( - message, i + 1, number_of_forensic_msgs, msg_uid - ) + f"Moving message {i+1} of {number_of_forensic_msgs}: UID {msg_uid}" ) try: connection.move_message(msg_uid, forensic_reports_folder) except Exception as e: - message = "Mailbox error: Error moving message UID {0}: {1}".format( - msg_uid, repr(e) - ) - logger.error(message) + logger.error(f"Mailbox error: Error moving message UID {msg_uid}: {e!r}") + results = OrderedDict( [("aggregate_reports", aggregate_reports), ("forensic_reports", forensic_reports)] ) @@ -1429,7 +1390,7 @@ def save_output( if os.path.exists(output_directory): if not os.path.isdir(output_directory): - raise ValueError("{0} is not a directory".format(output_directory)) + raise ValueError(f"{output_directory} is not a directory") else: os.makedirs(output_directory) @@ -1461,11 +1422,11 @@ def save_output( while filename in sample_filenames: message_count += 1 - filename = "{0} ({1})".format(subject, message_count) + filename = f"{subject} ({message_count})" sample_filenames.append(filename) - filename = "{0}.eml".format(filename) + filename = f"{filename}.eml" path = os.path.join(samples_directory, filename) with open(path, "w", newline="\n", encoding="utf-8") as sample_file: sample_file.write(sample) @@ -1550,21 +1511,21 @@ def email_results( attachment_filename: Override the default attachment filename message: Override the default plain text body """ - logger.debug("Emailing report to: {0}".format(",".join(mail_to))) + logger.debug(f"Emailing report to: {''.join(mail_to)}") date_string = datetime.now().strftime("%Y-%m-%d") if attachment_filename: if not attachment_filename.lower().endswith(".zip"): attachment_filename += ".zip" filename = attachment_filename else: - filename = "DMARC-{0}.zip".format(date_string) + filename = f"DMARC-{date_string}.zip" assert isinstance(mail_to, list) if subject is None: - subject = "DMARC results for {0}".format(date_string) + subject = f"DMARC results for {date_string}" if message is None: - message = "DMARC results for {0}".format(date_string) + message = f"DMARC results for {date_string}" zip_bytes = get_report_zip(results) attachments = [(filename, zip_bytes)] diff --git a/parsedmarc/cli.py b/parsedmarc/cli.py index f07201ae..587b4663 100644 --- a/parsedmarc/cli.py +++ b/parsedmarc/cli.py @@ -1,6 +1,4 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- - """A CLI for parsing DMARC reports""" from argparse import Namespace, ArgumentParser @@ -85,7 +83,7 @@ def _main(): """Called when the module is executed""" def process_reports(reports_): - output_str = "{0}\n".format(json.dumps(reports_, ensure_ascii=False, indent=2)) + output_str = json.dumps(reports_, ensure_ascii=False, indent=2) + "\n" if not opts.silent: print(output_str) @@ -114,31 +112,31 @@ def process_reports(reports_): except elastic.AlreadySaved as warning: logger.warning(warning.__str__()) except elastic.ElasticsearchError as error_: - logger.error("Elasticsearch Error: {0}".format(error_.__str__())) + logger.error(f"Elasticsearch Error: {error_!r}") except Exception as error_: - logger.error("Elasticsearch exception error: {}".format(error_.__str__())) + logger.error(f"Elasticsearch exception error: {error_!r}") try: if opts.kafka_hosts: 
kafka_client.save_aggregate_reports_to_kafka(report, kafka_aggregate_topic) except Exception as error_: - logger.error("Kafka Error: {0}".format(error_.__str__())) + logger.error(f"Kafka Error: {error_!r}") try: if opts.s3_bucket: s3_client.save_aggregate_report_to_s3(report) except Exception as error_: - logger.error("S3 Error: {0}".format(error_.__str__())) + logger.error(f"S3 Error: {error_!r}") try: if opts.syslog_server: syslog_client.save_aggregate_report_to_syslog(report) except Exception as error_: - logger.error("Syslog Error: {0}".format(error_.__str__())) + logger.error(f"Syslog Error: {error_!r}") if opts.hec: try: aggregate_reports_ = reports_["aggregate_reports"] if len(aggregate_reports_) > 0: hec_client.save_aggregate_reports_to_splunk(aggregate_reports_) except splunk.SplunkError as e: - logger.error("Splunk HEC error: {0}".format(e.__str__())) + logger.error(f"Splunk HEC error: {e!r}") if opts.save_forensic: for report in reports_["forensic_reports"]: try: @@ -155,31 +153,31 @@ def process_reports(reports_): except elastic.AlreadySaved as warning: logger.warning(warning.__str__()) except elastic.ElasticsearchError as error_: - logger.error("Elasticsearch Error: {0}".format(error_.__str__())) + logger.error(f"Elasticsearch Error: {error_!r}") except InvalidDMARCReport as error_: logger.error(error_.__str__()) try: if opts.kafka_hosts: kafka_client.save_forensic_reports_to_kafka(report, kafka_forensic_topic) except Exception as error_: - logger.error("Kafka Error: {0}".format(error_.__str__())) + logger.error(f"Kafka Error: {error_!r}") try: if opts.s3_bucket: s3_client.save_forensic_report_to_s3(report) except Exception as error_: - logger.error("S3 Error: {0}".format(error_.__str__())) + logger.error(f"S3 Error: {error_!r}") try: if opts.syslog_server: syslog_client.save_forensic_report_to_syslog(report) except Exception as error_: - logger.error("Syslog Error: {0}".format(error_.__str__())) + logger.error(f"Syslog Error: {error_!r}") if opts.hec: try: forensic_reports_ = reports_["forensic_reports"] if len(forensic_reports_) > 0: hec_client.save_forensic_reports_to_splunk(forensic_reports_) except splunk.SplunkError as e: - logger.error("Splunk HEC error: {0}".format(e.__str__())) + logger.error(f"Splunk HEC error: {e!r}") if opts.la_dce: try: la_client = loganalytics.LogAnalyticsClient( @@ -193,7 +191,7 @@ def process_reports(reports_): ) la_client.publish_results(reports_, opts.save_aggregate, opts.save_forensic) except loganalytics.LogAnalyticsException as e: - logger.error("Log Analytics error: {0}".format(e.__str__())) + logger.error(f"Log Analytics error: {e!r}") except Exception as e: logger.error( "Unknown error occured" @@ -368,7 +366,7 @@ def process_reports(reports_): if args.config_file: abs_path = os.path.abspath(args.config_file) if not os.path.exists(abs_path): - logger.error("A file does not exist at {0}".format(abs_path)) + logger.error(f"A file does not exist at {abs_path}") exit(-1) opts.silent = True config = ConfigParser() @@ -773,7 +771,7 @@ def process_reports(reports_): fh.setFormatter(formatter) logger.addHandler(fh) except Exception as error: - logger.warning("Unable to write to log file: {}".format(error)) + logger.warning(f"Unable to write to log file: {error!r}") if ( opts.imap_host is None @@ -793,8 +791,8 @@ def process_reports(reports_): es_forensic_index = "dmarc_forensic" if opts.elasticsearch_index_suffix: suffix = opts.elasticsearch_index_suffix - es_aggregate_index = "{0}_{1}".format(es_aggregate_index, suffix) - es_forensic_index = 
"{0}_{1}".format(es_forensic_index, suffix) + es_aggregate_index = f"{es_aggregate_index}_{suffix}" + es_forensic_index = f"{es_forensic_index}_{suffix}" elastic.set_hosts( opts.elasticsearch_hosts, opts.elasticsearch_ssl, @@ -822,7 +820,7 @@ def process_reports(reports_): secret_access_key=opts.s3_secret_access_key, ) except Exception as error_: - logger.error("S3 Error: {0}".format(error_.__str__())) + logger.error(f"S3 Error: {error_!r}") if opts.syslog_server: try: @@ -831,7 +829,7 @@ def process_reports(reports_): server_port=int(opts.syslog_port), ) except Exception as error_: - logger.error("Syslog Error: {0}".format(error_.__str__())) + logger.error(f"Syslog Error: {error_!r}") if opts.hec: if opts.hec_token is None or opts.hec_index is None: @@ -858,7 +856,7 @@ def process_reports(reports_): ssl_context=ssl_context, ) except Exception as error_: - logger.error("Kafka Error: {0}".format(error_.__str__())) + logger.error(f"Kafka Error: {error!r}") kafka_aggregate_topic = opts.kafka_aggregate_topic kafka_forensic_topic = opts.kafka_forensic_topic @@ -908,7 +906,7 @@ def process_reports(reports_): for result in results: if type(result[0]) is InvalidDMARCReport: - logger.error("Failed to parse {0} - {1}".format(result[1], result[0])) + logger.error(f"Failed to parse {result[1]} - {result[0]}") else: if result[0]["report_type"] == "aggregate": aggregate_reports.append(result[0]["report"]) @@ -1070,7 +1068,7 @@ def process_reports(reports_): offline=opts.offline, ) except FileExistsError as error: - logger.error("{0}".format(error.__str__())) + logger.error(f"{error!r}") exit(1) diff --git a/parsedmarc/elastic.py b/parsedmarc/elastic.py index 14a5e817..6cd2b094 100644 --- a/parsedmarc/elastic.py +++ b/parsedmarc/elastic.py @@ -1,6 +1,5 @@ -# -*- coding: utf-8 -*- - from collections import OrderedDict +from typing import Optional, Union, List, Dict, Any from elasticsearch_dsl.search import Q from elasticsearch_dsl import ( @@ -27,6 +26,12 @@ class ElasticsearchError(Exception): """Raised when an Elasticsearch error occurs""" + def __init__(self, message: Union[str, Exception]): + if isinstance(message, Exception): + message = repr(message) + super().__init__(f"Elasticsearch Error: {message}") + return + class _PolicyOverride(InnerDoc): type = Text() @@ -173,27 +178,26 @@ class AlreadySaved(ValueError): def set_hosts( - hosts, - use_ssl=False, - ssl_cert_path=None, - username=None, - password=None, - apiKey=None, - timeout=60.0, -): - """ - Sets the Elasticsearch hosts to use + hosts: Union[str, List[str]], + use_ssl: bool = False, + ssl_cert_path: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + apiKey: Optional[str] = None, + timeout: float = 60.0, +) -> None: + """Set the Elasticsearch host(s) to use Args: - hosts (str): A single hostname or URL, or list of hostnames or URLs - use_ssl (bool): Use a HTTPS connection to the server - ssl_cert_path (str): Path to the certificate chain - username (str): The username to use for authentication - password (str): The password to use for authentication - apiKey (str): The Base64 encoded API key to use for authentication - timeout (float): Timeout in seconds + hosts: A single hostname or URL, or list of hostnames or URLs + use_ssl: Use a HTTPS connection to the server + ssl_cert_path: Path to the certificate chain + username: The username to use for authentication + password: The password to use for authentication + apiKey: The Base64 encoded API key to use for authentication + timeout: Timeout in seconds 
""" - if not isinstance(hosts, list): + if isinstance(hosts, str): hosts = [hosts] conn_params = {"hosts": hosts, "timeout": timeout} if use_ssl: @@ -203,43 +207,45 @@ def set_hosts( conn_params["ca_certs"] = ssl_cert_path else: conn_params["verify_certs"] = False - if username: + if username and password: conn_params["http_auth"] = username + ":" + password if apiKey: conn_params["api_key"] = apiKey connections.create_connection(**conn_params) + return -def create_indexes(names, settings=None): - """ - Create Elasticsearch indexes +def create_indexes(names: List[str], settings: Optional[Dict[str, int]] = None) -> None: + """Create Elasticsearch indexes Args: - names (list): A list of index names - settings (dict): Index settings + names: A list of index names + settings: Index settings """ for name in names: index = Index(name) try: if not index.exists(): - logger.debug("Creating Elasticsearch index: {0}".format(name)) + logger.debug(f"Creating Elasticsearch index: {name}") if settings is None: index.settings(number_of_shards=1, number_of_replicas=0) else: index.settings(**settings) index.create() except Exception as e: - raise ElasticsearchError("Elasticsearch error: {0}".format(e.__str__())) + raise ElasticsearchError(e) + return -def migrate_indexes(aggregate_indexes=None, forensic_indexes=None): - """ - Updates index mappings +def migrate_indexes( + aggregate_indexes: Optional[List[str]] = None, forensic_indexes: Optional[List[str]] = None +): + """Update index mappings Args: - aggregate_indexes (list): A list of aggregate index names - forensic_indexes (list): A list of forensic index names + aggregate_indexes: A list of aggregate index names + forensic_indexes: A list of forensic index names """ version = 2 if aggregate_indexes is None: @@ -261,7 +267,7 @@ def migrate_indexes(aggregate_indexes=None, forensic_indexes=None): fo_mapping = fo_mapping[doc][fo_field]["mapping"][fo] fo_type = fo_mapping["type"] if fo_type == "long": - new_index_name = "{0}-v{1}".format(aggregate_index_name, version) + new_index_name = f"{aggregate_index_name}-v{version}" body = { "properties": { "published_policy.fo": { @@ -280,21 +286,21 @@ def migrate_indexes(aggregate_indexes=None, forensic_indexes=None): def save_aggregate_report_to_elasticsearch( - aggregate_report, - index_suffix=None, - monthly_indexes=False, - number_of_shards=1, - number_of_replicas=0, -): + aggregate_report: OrderedDict[str, Any], + index_suffix: Optional[str] = None, + monthly_indexes: bool = False, + number_of_shards: int = 1, + number_of_replicas: int = 0, +) -> None: """ Saves a parsed DMARC aggregate report to ElasticSearch Args: - aggregate_report (OrderedDict): A parsed forensic report - index_suffix (str): The suffix of the name of the index to save to - monthly_indexes (bool): Use monthly indexes instead of daily indexes - number_of_shards (int): The number of shards to use in the index - number_of_replicas (int): The number of replicas to use in the index + aggregate_report: A parsed forensic report + index_suffix: The suffix of the name of the index to save to + monthly_indexes: Use monthly indexes instead of daily indexes + number_of_shards: The number of shards to use in the index + number_of_replicas: The number of replicas to use in the index Raises: AlreadySaved @@ -324,7 +330,7 @@ def save_aggregate_report_to_elasticsearch( end_date_query = Q(dict(match=dict(date_end=end_date))) if index_suffix is not None: - search = Search(index="dmarc_aggregate_{0}*".format(index_suffix)) + search = 
Search(index=f"dmarc_aggregate_{index_suffix}*") else: search = Search(index="dmarc_aggregate*") query = org_name_query & report_id_query & domain_query @@ -333,20 +339,14 @@ def save_aggregate_report_to_elasticsearch( try: existing = search.execute() - except Exception as error_: - raise ElasticsearchError( - "Elasticsearch's search for existing report \ - error: {}".format( - error_.__str__() - ) - ) + except Exception as e: + raise ElasticsearchError(f"Search for existing report error: {e!r}") if len(existing) > 0: raise AlreadySaved( - "An aggregate report ID {0} from {1} about {2} " - "with a date range of {3} UTC to {4} UTC already " - "exists in " - "Elasticsearch".format(report_id, org_name, domain, begin_date_human, end_date_human) + f"An aggregate report ID {report_id} from {org_name} about {domain} " + f"with a date range of {begin_date_human} UTC to {end_date_human} UTC already " + "exists in Elasticsearch" ) published_policy = _PublishedPolicy( domain=aggregate_report["policy_published"]["domain"], @@ -402,8 +402,8 @@ def save_aggregate_report_to_elasticsearch( index = "dmarc_aggregate" if index_suffix: - index = "{0}_{1}".format(index, index_suffix) - index = "{0}-{1}".format(index, index_date) + index = f"{index}_{index_suffix}" + index = f"{index}-{index_date}" index_settings = dict( number_of_shards=number_of_shards, number_of_replicas=number_of_replicas ) @@ -413,31 +413,28 @@ def save_aggregate_report_to_elasticsearch( try: agg_doc.save() except Exception as e: - raise ElasticsearchError("Elasticsearch error: {0}".format(e.__str__())) + raise ElasticsearchError(e) + return def save_forensic_report_to_elasticsearch( - forensic_report, - index_suffix=None, - monthly_indexes=False, - number_of_shards=1, - number_of_replicas=0, -): - """ - Saves a parsed DMARC forensic report to ElasticSearch + forensic_report: OrderedDict[str, Any], + index_suffix: Optional[str] = None, + monthly_indexes: bool = False, + number_of_shards: int = 1, + number_of_replicas: int = 0, +) -> None: + """Save a parsed DMARC forensic report to ElasticSearch Args: - forensic_report (OrderedDict): A parsed forensic report - index_suffix (str): The suffix of the name of the index to save to - monthly_indexes (bool): Use monthly indexes instead of daily - indexes - number_of_shards (int): The number of shards to use in the index - number_of_replicas (int): The number of replicas to use in the - index + forensic_report: A parsed forensic report + index_suffix: The suffix of the name of the index to save to + monthly_indexes: Use monthly indexes instead of daily indexes + number_of_shards: The number of shards to use in the index + number_of_replicas: The number of replicas to use in the index Raises: AlreadySaved - """ logger.info("Saving forensic report to Elasticsearch") forensic_report = forensic_report.copy() @@ -454,7 +451,7 @@ def save_forensic_report_to_elasticsearch( arrival_date = human_timestamp_to_datetime(arrival_date_human) if index_suffix is not None: - search = Search(index="dmarc_forensic_{0}*".format(index_suffix)) + search = Search(index=f"dmarc_forensic_{index_suffix}*") else: search = Search(index="dmarc_forensic*") arrival_query = {"match": {"arrival_date": arrival_date}} @@ -481,10 +478,8 @@ def save_forensic_report_to_elasticsearch( if len(existing) > 0: raise AlreadySaved( - "A forensic sample to {0} from {1} " - "with a subject of {2} and arrival date of {3} " - "already exists in " - "Elasticsearch".format(to_, from_, subject, arrival_date_human) + f"A forensic sample to {to_} 
from {from_} with a subject of {subject} " + f"and arrival date of {arrival_date_human} already exists in Elasticsearch" ) parsed_sample = forensic_report["parsed_sample"] @@ -536,12 +531,12 @@ def save_forensic_report_to_elasticsearch( index = "dmarc_forensic" if index_suffix: - index = "{0}_{1}".format(index, index_suffix) + index = f"{index}_{index_suffix}" if monthly_indexes: index_date = arrival_date.strftime("%Y-%m") else: index_date = arrival_date.strftime("%Y-%m-%d") - index = "{0}-{1}".format(index, index_date) + index = f"{index}-{index_date}" index_settings = dict( number_of_shards=number_of_shards, number_of_replicas=number_of_replicas ) @@ -550,8 +545,7 @@ def save_forensic_report_to_elasticsearch( try: forensic_doc.save() except Exception as e: - raise ElasticsearchError("Elasticsearch error: {0}".format(e.__str__())) + raise ElasticsearchError(e) except KeyError as e: - raise InvalidForensicReport( - "Forensic report missing required field: {0}".format(e.__str__()) - ) + raise InvalidForensicReport(f"Forensic report missing required field: {e!r}") + return diff --git a/parsedmarc/kafkaclient.py b/parsedmarc/kafkaclient.py index 53feeb51..61626a1d 100644 --- a/parsedmarc/kafkaclient.py +++ b/parsedmarc/kafkaclient.py @@ -1,13 +1,12 @@ -# -*- coding: utf-8 -*- - +from collections import OrderedDict import json -from ssl import create_default_context +from ssl import create_default_context, SSLContext +from typing import Optional, Union, List, Dict, Any from kafka import KafkaProducer from kafka.errors import NoBrokersAvailable, UnknownTopicOrPartitionError -from collections import OrderedDict -from parsedmarc.utils import human_timestamp_to_datetime +from parsedmarc.utils import human_timestamp_to_datetime from parsedmarc import __version__ from parsedmarc.log import logger @@ -15,31 +14,38 @@ class KafkaError(RuntimeError): """Raised when a Kafka error occurs""" + def __init__(self, message: Union[str, Exception]): + if isinstance(message, Exception): + message = repr(message) + super().__init__(f"Kafka Error: {message}") + return + class KafkaClient(object): - def __init__(self, kafka_hosts, ssl=False, username=None, password=None, ssl_context=None): + def __init__( + self, + kafka_hosts: List[str], + ssl: bool = False, + username: Optional[str] = None, + password: Optional[str] = None, + ssl_context: Optional[SSLContext] = None, + ): """ - Initializes the Kafka client Args: - kafka_hosts (list): A list of Kafka hostnames - (with optional port numbers) - ssl (bool): Use a SSL/TLS connection - username (str): An optional username - password (str): An optional password + kafka_hosts: A list of Kafka hostnames (with optional port numbers) + ssl: Use a SSL/TLS connection. This is implied `True` if `username` or `password` is supplied. + username: An optional username + password: An optional password ssl_context: SSL context options - Notes: - ``use_ssl=True`` is implied when a username or password are - supplied. - - When using Azure Event Hubs, the username is literally - ``$ConnectionString``, and the password is the + Note: + When using Azure Event Hubs, the `username` is literally `$ConnectionString`, and the `password` is the Azure Event Hub connection string. 
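An illustrative sketch of the Azure Event Hubs case described in that note; the namespace host, port, and connection string are placeholders:

    from parsedmarc.kafkaclient import KafkaClient

    # SSL is implied because a username/password pair is supplied
    client = KafkaClient(
        kafka_hosts=["example-namespace.servicebus.windows.net:9093"],
        username="$ConnectionString",
        password="Endpoint=sb://example-namespace.servicebus.windows.net/;SharedAccessKeyName=...",
    )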
""" config = dict( value_serializer=lambda v: json.dumps(v).encode("utf-8"), bootstrap_servers=kafka_hosts, - client_id="parsedmarc-{0}".format(__version__), + client_id=f"parsedmarc-{__version__}", ) if ssl or username or password: config["security_protocol"] = "SSL" @@ -81,23 +87,24 @@ def generate_daterange(report): begin_date_human = begin_date.strftime("%Y-%m-%dT%H:%M:%S") end_date_human = end_date.strftime("%Y-%m-%dT%H:%M:%S") date_range = [begin_date_human, end_date_human] - logger.debug("date_range is {}".format(date_range)) + logger.debug(f"date_range is {date_range}") return date_range - def save_aggregate_reports_to_kafka(self, aggregate_reports, aggregate_topic): + def save_aggregate_reports_to_kafka( + self, aggregate_reports: Union[OrderedDict, List[OrderedDict]], aggregate_topic: str + ) -> None: """ Saves aggregate DMARC reports to Kafka Args: - aggregate_reports (list): A list of aggregate report dictionaries - to save to Kafka - aggregate_topic (str): The name of the Kafka topic + aggregate_reports: Aggregate reports to save to Kafka + aggregate_topic: The name of the Kafka topic """ - if isinstance(aggregate_reports, dict) or isinstance(aggregate_reports, OrderedDict): + if isinstance(aggregate_reports, dict): aggregate_reports = [aggregate_reports] - if len(aggregate_reports) < 1: + if not aggregate_reports: return for report in aggregate_reports: @@ -115,40 +122,43 @@ def save_aggregate_reports_to_kafka(self, aggregate_reports, aggregate_topic): logger.debug("Saving aggregate report to Kafka") self.producer.send(aggregate_topic, slice) except UnknownTopicOrPartitionError: - raise KafkaError("Kafka error: Unknown topic or partition on broker") + raise KafkaError("Unknown topic or partition on broker") except Exception as e: - raise KafkaError("Kafka error: {0}".format(e.__str__())) + raise KafkaError(e) try: self.producer.flush() except Exception as e: - raise KafkaError("Kafka error: {0}".format(e.__str__())) + raise KafkaError(e) + return - def save_forensic_reports_to_kafka(self, forensic_reports, forensic_topic): + def save_forensic_reports_to_kafka( + self, forensic_reports: Union[OrderedDict, List[OrderedDict]], forensic_topic: str + ) -> None: """ Saves forensic DMARC reports to Kafka, sends individual records (slices) since Kafka requires messages to be <= 1MB by default. 
Args: - forensic_reports (list): A list of forensic report dicts - to save to Kafka - forensic_topic (str): The name of the Kafka topic + forensic_reports: Forensic reports to save to Kafka + forensic_topic: The name of the Kafka topic """ if isinstance(forensic_reports, dict): forensic_reports = [forensic_reports] - if len(forensic_reports) < 1: + if not forensic_reports: return try: logger.debug("Saving forensic reports to Kafka") self.producer.send(forensic_topic, forensic_reports) except UnknownTopicOrPartitionError: - raise KafkaError("Kafka error: Unknown topic or partition on broker") + raise KafkaError("Unknown topic or partition on broker") except Exception as e: - raise KafkaError("Kafka error: {0}".format(e.__str__())) + raise KafkaError(e) try: self.producer.flush() except Exception as e: - raise KafkaError("Kafka error: {0}".format(e.__str__())) + raise KafkaError(e) + return diff --git a/parsedmarc/mail/gmail.py b/parsedmarc/mail/gmail.py index 09bdbd1c..1eed6884 100644 --- a/parsedmarc/mail/gmail.py +++ b/parsedmarc/mail/gmail.py @@ -52,6 +52,15 @@ def __init__( reports_folder: str, oauth2_port: int, ): + """ + Args: + token_file: + credentials_file: + scopes: + include_spam_trash: + reports_folder: + oauth2_port: + """ creds = _get_creds(token_file, credentials_file, scopes, oauth2_port) self.service = build("gmail", "v1", credentials=creds) self.include_spam_trash = include_spam_trash diff --git a/parsedmarc/mail/graph.py b/parsedmarc/mail/graph.py index 07a383e5..ad1c6bae 100644 --- a/parsedmarc/mail/graph.py +++ b/parsedmarc/mail/graph.py @@ -104,6 +104,18 @@ def __init__( token_file: str, allow_unencrypted_storage: bool, ): + """ + Args: + auth_method: + mailbox: + client_id: + client_secret: + username: + password: + tenant_id: + token_file: + allow_unencrypted_storage: + """ token_path = Path(token_file) credential = _generate_credential( auth_method, diff --git a/parsedmarc/mail/imap.py b/parsedmarc/mail/imap.py index 81901a03..f41a1c26 100644 --- a/parsedmarc/mail/imap.py +++ b/parsedmarc/mail/imap.py @@ -1,4 +1,5 @@ from time import sleep +from typing import Optional from imapclient.exceptions import IMAPClientError from mailsuite.imap import IMAPClient @@ -13,15 +14,26 @@ class IMAPConnection(MailboxConnection): def __init__( self, - host=None, - user=None, - password=None, - port=None, - ssl=True, - verify=True, - timeout=30, - max_retries=4, + host: Optional[str] = None, + user: Optional[str] = None, + password: Optional[str] = None, + port: Optional[int] = None, + ssl: bool = True, + verify: bool = True, + timeout: int = 30, + max_retries: int = 4, ): + """ + Args: + host: Host to connect to + user: + password: + port: Port to connect to + ssl: Use SSL/TLS + verify: Verify the SSL/TLS certification + timeout: + max_retries: + """ self._username = user self._password = password self._verify = verify @@ -80,5 +92,5 @@ def idle_callback_wrapper(client: IMAPClient): logger.warning("IMAP connection timeout. Reconnecting...") sleep(check_timeout) except Exception as e: - logger.warning("IMAP connection error. {0}. " "Reconnecting...".format(e)) + logger.warning(f"IMAP connection error. {e!r}. 
" "Reconnecting...") sleep(check_timeout) diff --git a/parsedmarc/s3.py b/parsedmarc/s3.py index 0b0fd2b5..15df0420 100644 --- a/parsedmarc/s3.py +++ b/parsedmarc/s3.py @@ -1,6 +1,6 @@ -# -*- coding: utf-8 -*- - import json +from typing import Optional, Dict, Any + import boto3 from parsedmarc.log import logger @@ -11,17 +11,22 @@ class S3Client(object): """A client for a Amazon S3""" def __init__( - self, bucket_name, bucket_path, region_name, endpoint_url, access_key_id, secret_access_key + self, + bucket_name: str, + bucket_path: str, + region_name: Optional[str] = None, + endpoint_url: Optional[str] = None, + access_key_id: Optional[str] = None, + secret_access_key: Optional[str] = None, ): """ - Initializes the S3Client Args: - bucket_name (str): The S3 Bucket - bucket_path (str): The path to save reports - region_name (str): The region name - endpoint_url (str): The endpoint URL - access_key_id (str): The access key id - secret_access_key (str): The secret access key + bucket_name: The S3 Bucket + bucket_path: The path to save reports + region_name: The region name + endpoint_url: The endpoint URL + access_key_id: The access key id + secret_access_key: The secret access key """ self.bucket_name = bucket_name self.bucket_path = bucket_path @@ -42,14 +47,17 @@ def __init__( aws_secret_access_key=secret_access_key, ) self.bucket = self.s3.Bucket(self.bucket_name) + return - def save_aggregate_report_to_s3(self, report): + def save_aggregate_report_to_s3(self, report: Dict[str, Any]) -> None: self.save_report_to_s3(report, "aggregate") + return - def save_forensic_report_to_s3(self, report): + def save_forensic_report_to_s3(self, report: Dict[str, Any]) -> None: self.save_report_to_s3(report, "forensic") + return - def save_report_to_s3(self, report, report_type): + def save_report_to_s3(self, report: Dict[str, Any], report_type: str): report_date = human_timestamp_to_datetime(report["report_metadata"]["begin_date"]) report_id = report["report_metadata"]["report_id"] path_template = "{0}/{1}/year={2}/month={3:02d}/day={4:02d}/{5}.json" @@ -61,9 +69,7 @@ def save_report_to_s3(self, report, report_type): report_date.day, report_id, ) - logger.debug( - "Saving {0} report to s3://{1}/{2}".format(report_type, self.bucket_name, object_path) - ) + logger.debug("Saving {report_type} report to s3://{self.bucket_name}/{object_path}") object_metadata = { k: v for k, v in report["report_metadata"].items() if k in self.metadata_keys } diff --git a/parsedmarc/splunk.py b/parsedmarc/splunk.py index 9b316e23..746d3acf 100644 --- a/parsedmarc/splunk.py +++ b/parsedmarc/splunk.py @@ -1,6 +1,7 @@ from urllib.parse import urlparse import socket import json +from typing import Union, Optional, Dict, List, Any import urllib3 import requests @@ -15,6 +16,12 @@ class SplunkError(RuntimeError): """Raised when a Splunk API error occurs""" + def __init__(self, message: Union[str, Exception]): + if isinstance(message, Exception): + message = repr(message) + super().__init__(f"Splunk Error: {message}") + return + class HECClient(object): """A client for a Splunk HTTP Events Collector (HEC)""" @@ -22,21 +29,26 @@ class HECClient(object): # http://docs.splunk.com/Documentation/Splunk/latest/Data/AboutHEC # http://docs.splunk.com/Documentation/Splunk/latest/RESTREF/RESTinput#services.2Fcollector - def __init__(self, url, access_token, index, source="parsedmarc", verify=True, timeout=60): + def __init__( + self, + url: str, + access_token: str, + index: str, + source: str = "parsedmarc", + verify: bool = True, + 
timeout: int = 60, + ): """ - Initializes the HECClient - Args: - url (str): The URL of the HEC - access_token (str): The HEC access token - index (str): The name of the index - source (str): The source name - verify (bool): Verify SSL certificates - timeout (float): Number of seconds to wait for the server to send - data before giving up + url: The URL of the HEC + access_token: The HEC access token + index: The name of the index + source: The source name + verify: Verify SSL certificates + timeout: Number of seconds to wait for the server to send data before giving up """ - url = urlparse(url) - self.url = "{0}://{1}/services/collector/event/1.0".format(url.scheme, url.netloc) + parsed = urlparse(url) + self.url = f"{parsed.scheme}://{parsed.netloc}/services/collector/event/1.0" self.access_token = access_token.lstrip("Splunk ") self.index = index self.host = socket.getfqdn() @@ -44,27 +56,29 @@ def __init__(self, url, access_token, index, source="parsedmarc", verify=True, t self.session = requests.Session() self.timeout = timeout self.session.verify = verify - self._common_data = dict(host=self.host, source=self.source, index=self.index) + self._common_data: Dict[str, Any] = dict( + host=self.host, source=self.source, index=self.index + ) self.session.headers = { - "User-Agent": "parsedmarc/{0}".format(__version__), - "Authorization": "Splunk {0}".format(self.access_token), + "User-Agent": f"parsedmarc/{__version__}", + "Authorization": f"Splunk {self.access_token}", } + return - def save_aggregate_reports_to_splunk(self, aggregate_reports): - """ - Saves aggregate DMARC reports to Splunk + def save_aggregate_reports_to_splunk( + self, aggregate_reports: Union[Dict, List[Dict[str, Any]]] + ): + """Save aggregate DMARC reports to Splunk Args: - aggregate_reports: A list of aggregate report dictionaries - to save in Splunk - + aggregate_reports: Aggregate reports to save in Splunk """ logger.debug("Saving aggregate reports to Splunk") if isinstance(aggregate_reports, dict): aggregate_reports = [aggregate_reports] - if len(aggregate_reports) < 1: + if not aggregate_reports: return data = self._common_data.copy() @@ -95,31 +109,29 @@ def save_aggregate_reports_to_splunk(self, aggregate_reports): timestamp = human_timestamp_to_timestamp(new_report["begin_date"]) data["time"] = timestamp data["event"] = new_report.copy() - json_str += "{0}\n".format(json.dumps(data)) + json_str += json.dumps(data) + "\n" if not self.session.verify: logger.debug("Skipping certificate verification for Splunk HEC") try: - response = self.session.post(self.url, data=json_str, timeout=self.timeout) - response = response.json() + response = self.session.post(self.url, data=json_str, timeout=self.timeout).json() except Exception as e: - raise SplunkError(e.__str__()) + raise SplunkError(e) if response["code"] != 0: raise SplunkError(response["text"]) + return - def save_forensic_reports_to_splunk(self, forensic_reports): - """ - Saves forensic DMARC reports to Splunk + def save_forensic_reports_to_splunk(self, forensic_reports: Union[Dict, List[Dict[str, Any]]]): + """Save forensic DMARC reports to Splunk Args: - forensic_reports (list): A list of forensic report dictionaries - to save in Splunk + forensic_reports: Forensic reports to save in Splunk """ logger.debug("Saving forensic reports to Splunk") if isinstance(forensic_reports, dict): forensic_reports = [forensic_reports] - if len(forensic_reports) < 1: + if not forensic_reports: return json_str = "" @@ -129,14 +141,14 @@ def 
save_forensic_reports_to_splunk(self, forensic_reports): timestamp = human_timestamp_to_timestamp(report["arrival_date_utc"]) data["time"] = timestamp data["event"] = report.copy() - json_str += "{0}\n".format(json.dumps(data)) + json_str += json.dumps(data) + "\n" if not self.session.verify: logger.debug("Skipping certificate verification for Splunk HEC") try: - response = self.session.post(self.url, data=json_str, timeout=self.timeout) - response = response.json() + response = self.session.post(self.url, data=json_str, timeout=self.timeout).json() except Exception as e: - raise SplunkError(e.__str__()) + raise SplunkError(e) if response["code"] != 0: raise SplunkError(response["text"]) + return diff --git a/parsedmarc/syslog.py b/parsedmarc/syslog.py index d01785ef..2d56eb9e 100644 --- a/parsedmarc/syslog.py +++ b/parsedmarc/syslog.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import logging import logging.handlers import json @@ -10,12 +8,11 @@ class SyslogClient(object): """A client for Syslog""" - def __init__(self, server_name, server_port): + def __init__(self, server_name: str, server_port: int): """ - Initializes the SyslogClient Args: - server_name (str): The Syslog server - server_port (int): The Syslog UDP port + server_name: The Syslog server + server_port: The Syslog UDP port """ self.server_name = server_name self.server_port = server_port diff --git a/parsedmarc/utils.py b/parsedmarc/utils.py index fa7f07a1..cda3183f 100644 --- a/parsedmarc/utils.py +++ b/parsedmarc/utils.py @@ -119,7 +119,7 @@ def query_dns( """ domain = str(domain).lower() record_type = record_type.upper() - cache_key = "{0}_{1}".format(domain, record_type) + cache_key = f"{domain}_{record_type}" if cache: records = cache.get(cache_key, None) if records: @@ -191,13 +191,13 @@ def get_reverse_dns( def timestamp_to_datetime(timestamp: int) -> datetime: - """Converts a UNIX/DMARC timestamp to a Python ``datetime`` object + """Converts a UNIX/DMARC timestamp to a Python `datetime` object Args: timestamp: The timestamp Returns: - The converted timestamp as a Python ``datetime`` object + The converted timestamp as a Python `datetime` object """ return datetime.fromtimestamp(int(timestamp)) @@ -209,13 +209,13 @@ def timestamp_to_human(timestamp: int) -> str: timestamp: The timestamp Returns: - The converted timestamp in ``YYYY-MM-DD HH:MM:SS`` format + The converted timestamp in `YYYY-MM-DD HH:MM:SS` format """ return timestamp_to_datetime(timestamp).strftime("%Y-%m-%d %H:%M:%S") def human_timestamp_to_datetime(human_timestamp: str, to_utc: bool = False) -> datetime: - """Converts a human-readable timestamp into a Python ``datetime`` object + """Converts a human-readable timestamp into a Python `datetime` object Args: human_timestamp: A timestamp string @@ -236,7 +236,7 @@ def human_timestamp_to_timestamp(human_timestamp: str) -> float: """Converts a human-readable timestamp into a UNIX timestamp Args: - human_timestamp: A timestamp in `YYYY-MM-DD HH:MM:SS`` format + human_timestamp: A timestamp in `YYYY-MM-DD HH:MM:SS` format Returns: The converted timestamp @@ -272,9 +272,7 @@ def get_ip_address_country(ip_address: str, db_path: Optional[str] = None) -> Op if os.path.isfile(db_path) is False: db_path = None logger.warning( - f"No file exists at {db_path}. Falling back to an " - "included copy of the IPDB IP to Country " - "Lite database." + f"No file exists at {db_path}. Falling back to an included copy of the IPDB IP to Country Lite database." 
) if db_path is None: @@ -404,7 +402,7 @@ def is_mbox(path: str) -> bool: if len(mbox.keys()) > 0: _is_mbox = True except Exception as e: - logger.debug("Error checking for MBOX file: {0}".format(e.__str__())) + logger.debug(f"Error checking for MBOX file: {e!r}") return _is_mbox @@ -424,7 +422,7 @@ def is_outlook_msg(content: Any) -> bool: def convert_outlook_msg(msg_bytes: bytes) -> bytes: """Convert an Outlook MS file to standard RFC 822 format - Requires the ``msgconvert`` Perl utility to be installed. + Requires the `msgconvert` Perl utility to be installed. Args: msg_bytes: the content of the .msg file @@ -535,7 +533,7 @@ def parse_email(data: Union[bytes, str], strip_attachment_payloads: bool = False payload = str.encode(payload) attachment["sha256"] = hashlib.sha256(payload).hexdigest() except Exception as e: - logger.debug("Unable to decode attachment: {0}".format(e.__str__())) + logger.debug(f"Unable to decode attachment: {e!r}") if strip_attachment_payloads: for attachment in parsed_email["attachments"]: if "payload" in attachment: diff --git a/tests.py b/tests.py index 1f311ecf..79ec8e5f 100644 --- a/tests.py +++ b/tests.py @@ -16,6 +16,8 @@ def testBase64Decoding(self): decoded_str = parsedmarc.utils.decode_base64(b64_str) assert decoded_str == b"any carnal pleas" + return + def testPSLDownload(self): subdomain = "foo.example.com" result = parsedmarc.utils.get_base_domain(subdomain) @@ -26,6 +28,8 @@ def testPSLDownload(self): result = parsedmarc.utils.get_base_domain(subdomain) assert result == "c.akamaiedge.net" + return + def testAggregateSamples(self): """Test sample aggregate/rua DMARC reports""" print() @@ -33,16 +37,18 @@ def testAggregateSamples(self): for sample_path in sample_paths: if os.path.isdir(sample_path): continue - print("Testing {0}: " .format(sample_path), end="") - parsed_report = parsedmarc.parse_report_file( - sample_path)["report"] + print(f"Testing {sample_path}: ", end="") + parsed_report = parsedmarc.parse_report_file(sample_path)["report"] parsedmarc.parsed_aggregate_reports_to_csv(parsed_report) print("Passed!") + return + def testEmptySample(self): """Test empty/unparasable report""" with self.assertRaises(parsedmarc.InvalidDMARCReport): - parsedmarc.parse_report_file('samples/empty.xml') + parsedmarc.parse_report_file("samples/empty.xml") + return def testForensicSamples(self): """Test sample forensic/ruf/failure DMARC reports""" @@ -52,12 +58,11 @@ def testForensicSamples(self): print("Testing {0}: ".format(sample_path), end="") with open(sample_path) as sample_file: sample_content = sample_file.read() - parsed_report = parsedmarc.parse_report_email( - sample_content)["report"] - parsed_report = parsedmarc.parse_report_file( - sample_path)["report"] + parsed_report = parsedmarc.parse_report_email(sample_content)["report"] + parsed_report = parsedmarc.parse_report_file(sample_path)["report"] parsedmarc.parsed_forensic_reports_to_csv(parsed_report) print("Passed!") + return if __name__ == "__main__": From 306479f8358a41b37c561937c0c900e328037271 Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Wed, 3 Jan 2024 15:06:37 +1100 Subject: [PATCH 05/15] Fix linting, add dev dependencies to pyproject.toml - Run black -l 100 -t py37 - Run isort --- .flake8 | 2 ++ parsedmarc/__init__.py | 36 ++++++++++++-------- parsedmarc/cli.py | 44 +++++++++++++------------ parsedmarc/elastic.py | 23 +++++++------ parsedmarc/kafkaclient.py | 9 +++-- parsedmarc/log.py | 1 + parsedmarc/loganalytics.py | 5 ++- parsedmarc/mail/__init__.py | 5 +-- parsedmarc/mail/gmail.py 
| 3 ++ parsedmarc/mail/graph.py | 11 ++++--- parsedmarc/mail/imap.py | 5 ++- parsedmarc/mail/mailbox_connection.py | 1 + parsedmarc/s3.py | 5 ++- parsedmarc/splunk.py | 11 ++++--- parsedmarc/syslog.py | 9 +++-- parsedmarc/utils.py | 44 ++++++++++++------------- pyproject.toml | 47 +++++++++++++++++++++++++-- 17 files changed, 172 insertions(+), 89 deletions(-) create mode 100644 .flake8 diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..6deafc26 --- /dev/null +++ b/.flake8 @@ -0,0 +1,2 @@ +[flake8] +max-line-length = 120 diff --git a/parsedmarc/__init__.py b/parsedmarc/__init__.py index fd8bbec8..39958210 100644 --- a/parsedmarc/__init__.py +++ b/parsedmarc/__init__.py @@ -1,36 +1,44 @@ """A Python package for parsing DMARC reports""" +# Standard Library +from base64 import b64decode import binascii +from collections import OrderedDict +from csv import DictWriter +from datetime import datetime import email import email.utils +from io import BytesIO, StringIO import json import mailbox import os import re import shutil import tempfile +from typing import Any, BinaryIO, Callable, Dict, List, Optional, Union, cast import xml.parsers.expat as expat import zipfile import zlib -from base64 import b64decode -from collections import OrderedDict -from csv import DictWriter -from datetime import datetime -from io import BytesIO, StringIO -from typing import List, Dict, Any, Optional, Union, Callable, BinaryIO, cast -import mailparser -import xmltodict +# Installed from expiringdict import ExpiringDict from lxml import etree +import mailparser from mailsuite.smtp import send_email +import xmltodict +# Package from parsedmarc.log import logger from parsedmarc.mail import MailboxConnection -from parsedmarc.utils import get_base_domain, get_ip_address_info -from parsedmarc.utils import is_outlook_msg, convert_outlook_msg -from parsedmarc.utils import parse_email -from parsedmarc.utils import timestamp_to_human, human_timestamp_to_datetime +from parsedmarc.utils import ( + convert_outlook_msg, + get_base_domain, + get_ip_address_info, + human_timestamp_to_datetime, + is_outlook_msg, + parse_email, + timestamp_to_human, +) __version__ = "8.6.4" @@ -758,7 +766,7 @@ def parse_forensic_report( raise InvalidForensicReport(f"Missing value: {error!r}") except Exception as error: - raise InvalidForensicReport("Unexpected error: {error!r}") + raise InvalidForensicReport(f"Unexpected error: {error!r}") def parsed_forensic_reports_to_csv_rows( @@ -1299,7 +1307,7 @@ def watch_inbox( archive_folder: The folder to move processed mail to delete: Delete messages after processing them test: Do not move or delete messages after processing them - check_timeout: Number of seconds to wait for a IMAP IDLE response or the number of seconds until the next mail check + check_timeout: Number of seconds to wait for a IMAP IDLE response or the next mail check ip_db_path: Path to a MMDB file from MaxMind or DBIP offline: Do not query online for geolocation or DNS nameservers: A list of one or more nameservers to use (Cloudflare's public DNS resolvers by default) diff --git a/parsedmarc/cli.py b/parsedmarc/cli.py index 587b4663..eed7afc2 100644 --- a/parsedmarc/cli.py +++ b/parsedmarc/cli.py @@ -1,42 +1,44 @@ #!/usr/bin/env python3 """A CLI for parsing DMARC reports""" -from argparse import Namespace, ArgumentParser -import os +# Standard Library +from argparse import ArgumentParser, Namespace +from collections import OrderedDict from configparser import ConfigParser from glob import glob -import logging -from 
collections import OrderedDict +from itertools import repeat import json -from ssl import CERT_NONE, create_default_context +import logging from multiprocessing import Pool, Value -from itertools import repeat +import os +from ssl import CERT_NONE, create_default_context import sys import time + +# Installed from tqdm import tqdm +# Package from parsedmarc import ( + InvalidDMARCReport, + ParserError, + __version__, + elastic, + email_results, get_dmarc_reports_from_mailbox, - watch_inbox, - parse_report_file, get_dmarc_reports_from_mbox, - elastic, kafkaclient, - splunk, - save_output, - email_results, - ParserError, - __version__, - InvalidDMARCReport, + loganalytics, + parse_report_file, s3, + save_output, + splunk, syslog, - loganalytics, + watch_inbox, ) - -from parsedmarc.mail import IMAPConnection, MSGraphConnection, GmailConnection -from parsedmarc.mail.graph import AuthMethod - from parsedmarc.log import logger +from parsedmarc.mail import GmailConnection, IMAPConnection, MSGraphConnection +from parsedmarc.mail.graph import AuthMethod from parsedmarc.utils import is_mbox formatter = logging.Formatter( @@ -856,7 +858,7 @@ def process_reports(reports_): ssl_context=ssl_context, ) except Exception as error_: - logger.error(f"Kafka Error: {error!r}") + logger.error(f"Kafka Error: {error_!r}") kafka_aggregate_topic = opts.kafka_aggregate_topic kafka_forensic_topic = opts.kafka_forensic_topic diff --git a/parsedmarc/elastic.py b/parsedmarc/elastic.py index 6cd2b094..a2b2322e 100644 --- a/parsedmarc/elastic.py +++ b/parsedmarc/elastic.py @@ -1,26 +1,29 @@ +# Standard Library from collections import OrderedDict -from typing import Optional, Union, List, Dict, Any +from typing import Any, Dict, List, Optional, Union -from elasticsearch_dsl.search import Q +# Installed +from elasticsearch.helpers import reindex from elasticsearch_dsl import ( - connections, - Object, + Boolean, + Date, Document, Index, - Nested, InnerDoc, Integer, - Text, - Boolean, Ip, - Date, + Nested, + Object, Search, + Text, + connections, ) -from elasticsearch.helpers import reindex +from elasticsearch_dsl.search import Q +# Package +from parsedmarc import InvalidForensicReport from parsedmarc.log import logger from parsedmarc.utils import human_timestamp_to_datetime -from parsedmarc import InvalidForensicReport class ElasticsearchError(Exception): diff --git a/parsedmarc/kafkaclient.py b/parsedmarc/kafkaclient.py index 61626a1d..843e6a37 100644 --- a/parsedmarc/kafkaclient.py +++ b/parsedmarc/kafkaclient.py @@ -1,14 +1,17 @@ +# Standard Library from collections import OrderedDict import json -from ssl import create_default_context, SSLContext -from typing import Optional, Union, List, Dict, Any +from ssl import SSLContext, create_default_context +from typing import List, Optional, Union +# Installed from kafka import KafkaProducer from kafka.errors import NoBrokersAvailable, UnknownTopicOrPartitionError -from parsedmarc.utils import human_timestamp_to_datetime +# Package from parsedmarc import __version__ from parsedmarc.log import logger +from parsedmarc.utils import human_timestamp_to_datetime class KafkaError(RuntimeError): diff --git a/parsedmarc/log.py b/parsedmarc/log.py index c10988db..2c9f1cf2 100644 --- a/parsedmarc/log.py +++ b/parsedmarc/log.py @@ -1,3 +1,4 @@ +# Standard Library import logging logger = logging.getLogger(__name__) diff --git a/parsedmarc/loganalytics.py b/parsedmarc/loganalytics.py index 15686921..eb083252 100644 --- a/parsedmarc/loganalytics.py +++ b/parsedmarc/loganalytics.py @@ -1,9 
+1,12 @@ -from typing import List, Dict, Optional +# Standard Library +from typing import Dict, List, Optional +# Installed from azure.core.exceptions import HttpResponseError from azure.identity import ClientSecretCredential from azure.monitor.ingestion import LogsIngestionClient +# Package from parsedmarc.log import logger diff --git a/parsedmarc/mail/__init__.py b/parsedmarc/mail/__init__.py index 5d40b3ad..3bbe4e8b 100644 --- a/parsedmarc/mail/__init__.py +++ b/parsedmarc/mail/__init__.py @@ -1,6 +1,7 @@ -from parsedmarc.mail.mailbox_connection import MailboxConnection -from parsedmarc.mail.graph import MSGraphConnection +# Package from parsedmarc.mail.gmail import GmailConnection +from parsedmarc.mail.graph import MSGraphConnection from parsedmarc.mail.imap import IMAPConnection +from parsedmarc.mail.mailbox_connection import MailboxConnection __all__ = ["MailboxConnection", "MSGraphConnection", "GmailConnection", "IMAPConnection"] diff --git a/parsedmarc/mail/gmail.py b/parsedmarc/mail/gmail.py index 1eed6884..46884882 100644 --- a/parsedmarc/mail/gmail.py +++ b/parsedmarc/mail/gmail.py @@ -1,9 +1,11 @@ +# Standard Library from base64 import urlsafe_b64decode from functools import lru_cache from pathlib import Path from time import sleep from typing import TYPE_CHECKING, List +# Installed from google.auth.transport.requests import Request from google.oauth2.credentials import Credentials from google_auth_oauthlib.flow import InstalledAppFlow @@ -14,6 +16,7 @@ # https://github.com/henribru/google-api-python-client-stubs?tab=readme-ov-file#explicit-annotations from googleapiclient._apis.gmail.v1.schemas import ModifyMessageRequest, Label +# Package from parsedmarc.log import logger from parsedmarc.mail.mailbox_connection import MailboxConnection diff --git a/parsedmarc/mail/graph.py b/parsedmarc/mail/graph.py index ad1c6bae..33b70984 100644 --- a/parsedmarc/mail/graph.py +++ b/parsedmarc/mail/graph.py @@ -1,18 +1,21 @@ +# Standard Library from enum import Enum from functools import lru_cache from pathlib import Path from time import sleep -from typing import Dict, List, Optional, Union, Any +from typing import Any, Dict, List, Optional, Union +# Installed from azure.identity import ( - UsernamePasswordCredential, - DeviceCodeCredential, + AuthenticationRecord, ClientSecretCredential, + DeviceCodeCredential, TokenCachePersistenceOptions, - AuthenticationRecord, + UsernamePasswordCredential, ) from msgraph.core import GraphClient +# Package from parsedmarc.log import logger from parsedmarc.mail.mailbox_connection import MailboxConnection diff --git a/parsedmarc/mail/imap.py b/parsedmarc/mail/imap.py index f41a1c26..cb1c5016 100644 --- a/parsedmarc/mail/imap.py +++ b/parsedmarc/mail/imap.py @@ -1,10 +1,13 @@ +# Standard Library +from socket import timeout from time import sleep from typing import Optional +# Installed from imapclient.exceptions import IMAPClientError from mailsuite.imap import IMAPClient -from socket import timeout +# Package from parsedmarc.log import logger from parsedmarc.mail.mailbox_connection import MailboxConnection diff --git a/parsedmarc/mail/mailbox_connection.py b/parsedmarc/mail/mailbox_connection.py index 4f5d1b19..a36cbbdc 100644 --- a/parsedmarc/mail/mailbox_connection.py +++ b/parsedmarc/mail/mailbox_connection.py @@ -1,3 +1,4 @@ +# Standard Library from abc import ABC from typing import List, Union diff --git a/parsedmarc/s3.py b/parsedmarc/s3.py index 15df0420..1c7181da 100644 --- a/parsedmarc/s3.py +++ b/parsedmarc/s3.py @@ -1,8 +1,11 @@ +# Standard 
Library import json -from typing import Optional, Dict, Any +from typing import Any, Dict, Optional +# Installed import boto3 +# Package from parsedmarc.log import logger from parsedmarc.utils import human_timestamp_to_datetime diff --git a/parsedmarc/splunk.py b/parsedmarc/splunk.py index 746d3acf..59066ec7 100644 --- a/parsedmarc/splunk.py +++ b/parsedmarc/splunk.py @@ -1,11 +1,14 @@ -from urllib.parse import urlparse -import socket +# Standard Library import json -from typing import Union, Optional, Dict, List, Any +import socket +from typing import Any, Dict, List, Union +from urllib.parse import urlparse -import urllib3 +# Installed import requests +import urllib3 +# Package from parsedmarc import __version__ from parsedmarc.log import logger from parsedmarc.utils import human_timestamp_to_timestamp diff --git a/parsedmarc/syslog.py b/parsedmarc/syslog.py index 2d56eb9e..41021ff8 100644 --- a/parsedmarc/syslog.py +++ b/parsedmarc/syslog.py @@ -1,8 +1,13 @@ +# Standard Library +import json import logging import logging.handlers -import json -from parsedmarc import parsed_aggregate_reports_to_csv_rows, parsed_forensic_reports_to_csv_rows +# Package +from parsedmarc import ( + parsed_aggregate_reports_to_csv_rows, + parsed_forensic_reports_to_csv_rows, +) class SyslogClient(object): diff --git a/parsedmarc/utils.py b/parsedmarc/utils.py index cda3183f..f3ddd872 100644 --- a/parsedmarc/utils.py +++ b/parsedmarc/utils.py @@ -1,42 +1,37 @@ """Utility functions that might be useful for other projects""" -import logging -import os -from datetime import datetime -from datetime import timezone -from datetime import timedelta +# Standard Library +import atexit +import base64 from collections import OrderedDict -import tempfile -import subprocess -import shutil -import mailparser -import json +from datetime import datetime, timedelta, timezone import hashlib -import base64 -import atexit +import importlib.resources +import json +import logging import mailbox +import os import re -from typing import List, Dict, Any, Optional, Union - -try: - import importlib.resources as pkg_resources -except ImportError: - # Try backported to PY<37 `importlib_resources` - import importlib_resources as pkg_resources # type: ignore[no-redef] +import shutil +import subprocess +import tempfile +from typing import Any, Dict, List, Optional, Union +# Installed from dateutil.parser import parse as parse_date -import dns.reversename -import dns.resolver import dns.exception +import dns.resolver +import dns.reversename from expiringdict import ExpiringDict import geoip2.database import geoip2.errors +import mailparser import publicsuffixlist +# Package from parsedmarc.log import logger import parsedmarc.resources.dbip - parenthesis_regex = re.compile(r"\s*\(.*\)\s*") null_file = open(os.devnull, "w") @@ -272,7 +267,8 @@ def get_ip_address_country(ip_address: str, db_path: Optional[str] = None) -> Op if os.path.isfile(db_path) is False: db_path = None logger.warning( - f"No file exists at {db_path}. Falling back to an included copy of the IPDB IP to Country Lite database." + f"No file exists at {db_path}. " + "Falling back to an included copy of the IPDB IP to Country Lite database." 
) if db_path is None: @@ -282,7 +278,7 @@ def get_ip_address_country(ip_address: str, db_path: Optional[str] = None) -> Op break if db_path is None: - with pkg_resources.path(parsedmarc.resources.dbip, "dbip-country-lite.mmdb") as path: + with importlib.resources.path(parsedmarc.resources.dbip, "dbip-country-lite.mmdb") as path: db_path = str(path) db_age = datetime.now() - datetime.fromtimestamp(os.stat(db_path).st_mtime) diff --git a/pyproject.toml b/pyproject.toml index 261c6467..37f22a0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ dynamic = [ ] description = "A Python package and CLI for parsing aggregate and forensic DMARC reports" readme = "README.md" -license = "Apache-2.0" +license = {text = "Apache-2.0"} authors = [ { name = "Sean Whalen", email = "whalenster@gmail.com" }, ] @@ -26,7 +26,13 @@ classifiers = [ "Intended Audience :: Information Technology", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", - "Programming Language :: Python :: 3" + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", ] requires-python = ">=3.7" @@ -58,6 +64,31 @@ dependencies = [ "xmltodict>=0.12.0", ] +[project.optional-dependencies] +dev = [ + ## Type Stubs + "google-api-python-client-stubs", + "google-auth-stubs", + "lxml-stubs", + "types-boto3", + "types-python-dateutil", + "types-requests", + "types-tqdm", + "types-xmltodict", + ### dev.sh dependencies + ## Formatting / Linting + "validate-pyproject[all]", + "black", + "isort", + #"pylint", + "mypy", + ## Testing + "flake8", + #"pytest", + ## REPL + "bpython", +] + [project.scripts] parsedmarc = "parsedmarc.cli:_main" @@ -71,3 +102,15 @@ path = "parsedmarc/__init__.py" include = [ "/parsedmarc", ] + +[tool.isort] +profile = "black" +force_sort_within_sections = true +src_paths = ["parsedmarc", "tests.py"] + +# Section Headings +import_heading_future = "Future" +import_heading_stdlib = "Standard Library" +import_heading_thirdparty = "Installed" +import_heading_firstparty = "Package" +import_heading_localfolder = "Local" From 08a1f204ef5df52af9352ca992b88642dabf8983 Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Thu, 1 Feb 2024 02:24:28 +1100 Subject: [PATCH 06/15] Use annotations, move requirements to pyproject.toml --- .github/workflows/python-tests.yml | 2 +- parsedmarc/__init__.py | 132 ++++++++++++++------------ parsedmarc/cli.py | 20 +++- parsedmarc/elastic.py | 29 +++--- parsedmarc/kafkaclient.py | 19 ++-- parsedmarc/loganalytics.py | 13 +-- parsedmarc/mail/__init__.py | 7 +- parsedmarc/mail/gmail.py | 16 ++-- parsedmarc/mail/graph.py | 42 ++++---- parsedmarc/mail/imap.py | 11 ++- parsedmarc/mail/mailbox_connection.py | 7 +- parsedmarc/s3.py | 20 ++-- parsedmarc/splunk.py | 16 ++-- parsedmarc/syslog.py | 2 +- parsedmarc/utils.py | 40 +++++--- pyproject.toml | 24 ++++- requirements.txt | 37 -------- 17 files changed, 233 insertions(+), 204 deletions(-) diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 375414b6..44b3e8aa 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -33,7 +33,7 @@ jobs: - name: Install Python dependencies run: | python -m pip install --upgrade pip - if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + pip 
install -e '.[dev,docs]' - name: Test building documentation run: | cd docs diff --git a/parsedmarc/__init__.py b/parsedmarc/__init__.py index 39958210..4eab343c 100644 --- a/parsedmarc/__init__.py +++ b/parsedmarc/__init__.py @@ -1,5 +1,7 @@ """A Python package for parsing DMARC reports""" +from __future__ import annotations + # Standard Library from base64 import b64decode import binascii @@ -15,7 +17,7 @@ import re import shutil import tempfile -from typing import Any, BinaryIO, Callable, Dict, List, Optional, Union, cast +from typing import Any, BinaryIO, Callable, cast import xml.parsers.expat as expat import zipfile import zlib @@ -74,9 +76,9 @@ class InvalidForensicReport(InvalidDMARCReport): def _parse_report_record( record: OrderedDict, - ip_db_path: Optional[str] = None, + ip_db_path: str | None = None, offline: bool = False, - nameservers: Optional[List[str]] = None, + nameservers: list[str] | None = None, dns_timeout: float = 2.0, parallel: bool = False, ) -> OrderedDict: @@ -221,12 +223,12 @@ def _parse_report_record( def parse_aggregate_report_xml( xml: str, - ip_db_path: Optional[str] = None, + ip_db_path: str | None = None, offline: bool = False, - nameservers: Optional[List[str]] = None, + nameservers: list[str] | None = None, timeout: float = 2.0, parallel: bool = False, - keep_alive: Optional[Callable] = None, + keep_alive: Callable | None = None, ) -> OrderedDict[str, Any]: """Parses a DMARC XML report string and returns a consistent OrderedDict @@ -250,7 +252,8 @@ def parse_aggregate_report_xml( errors.append(f"Invalid XML: {e!r}") try: tree = etree.parse( - BytesIO(xml.encode("utf-8")), etree.XMLParser(recover=True, resolve_entities=False) + BytesIO(xml.encode("utf-8")), + etree.XMLParser(recover=True, resolve_entities=False), ) s = etree.tostring(tree) xml = "" if s is None else s.decode("utf-8") @@ -385,7 +388,7 @@ def parse_aggregate_report_xml( raise InvalidAggregateReport(f"Unexpected error: {error!r}") -def extract_xml(input_: Union[str, bytes, BinaryIO]) -> str: +def extract_xml(input_: str | bytes | BinaryIO) -> str: """Extracts xml from a zip or gzip file at the given path, file-like object, or bytes. Args: @@ -435,13 +438,13 @@ def extract_xml(input_: Union[str, bytes, BinaryIO]) -> str: def parse_aggregate_report_file( - _input: Union[bytes, str, BinaryIO], + _input: bytes | str | BinaryIO, offline: bool = False, - ip_db_path: Optional[str] = None, - nameservers: Optional[List[str]] = None, + ip_db_path: str | None = None, + nameservers: list[str] | None = None, dns_timeout: float = 2.0, parallel: bool = False, - keep_alive: Optional[Callable] = None, + keep_alive: Callable | None = None, ) -> OrderedDict[str, Any]: """Parse a file at the given path, a file-like object. 
or bytes as an aggregate DMARC report @@ -472,8 +475,8 @@ def parse_aggregate_report_file( def parsed_aggregate_reports_to_csv_rows( - reports: Union[OrderedDict, List[OrderedDict]] -) -> List[Dict[str, Union[str, int, bool]]]: + reports: OrderedDict | list[OrderedDict], +) -> list[dict[str, str | int | bool]]: """Convert one or more parsed aggregate reports to list of dicts in flat CSV format Args: @@ -483,10 +486,10 @@ def parsed_aggregate_reports_to_csv_rows( Parsed aggregate report data as a list of dicts in flat CSV format """ - def to_str(obj): + def to_str(obj) -> str: return str(obj).lower() - if type(reports) is OrderedDict: + if not isinstance(reports, list): reports = [reports] rows = [] @@ -538,7 +541,10 @@ def to_str(obj): row["dmarc_aligned"] = record["alignment"]["dmarc"] row["disposition"] = record["policy_evaluated"]["disposition"] policy_override_reasons = list( - map(lambda r_: r_["type"], record["policy_evaluated"]["policy_override_reasons"]) + map( + lambda r_: r_["type"], + record["policy_evaluated"]["policy_override_reasons"], + ) ) policy_override_comments = list( map( @@ -583,7 +589,7 @@ def to_str(obj): return rows -def parsed_aggregate_reports_to_csv(reports: Union[OrderedDict, List[OrderedDict]]) -> str: +def parsed_aggregate_reports_to_csv(reports: OrderedDict | list[OrderedDict]) -> str: """Convert one or more parsed aggregate reports to flat CSV format, including headers Args: @@ -649,8 +655,8 @@ def parse_forensic_report( sample: str, msg_date: datetime, offline: bool = False, - ip_db_path: Optional[str] = None, - nameservers: Optional[List[str]] = None, + ip_db_path: str | None = None, + nameservers: list[str] | None = None, dns_timeout: float = 2.0, strip_attachment_payloads: bool = False, parallel: bool = False, @@ -770,8 +776,8 @@ def parse_forensic_report( def parsed_forensic_reports_to_csv_rows( - reports: Union[OrderedDict, List[OrderedDict]] -) -> List[Dict[str, Any]]: + reports: OrderedDict | list[OrderedDict], +) -> list[dict[str, Any]]: """Convert one or more parsed forensic reports to a list of dicts in flat CSV format Args: @@ -803,7 +809,7 @@ def parsed_forensic_reports_to_csv_rows( return rows -def parsed_forensic_reports_to_csv(reports: Union[OrderedDict, List[OrderedDict]]) -> str: +def parsed_forensic_reports_to_csv(reports: OrderedDict | list[OrderedDict]) -> str: """Convert one or more parsed forensic reports to flat CSV format, including headers Args: @@ -843,7 +849,7 @@ def parsed_forensic_reports_to_csv(reports: Union[OrderedDict, List[OrderedDict] rows = parsed_forensic_reports_to_csv_rows(reports) for row in rows: - new_row: Dict[str, Any] = {} + new_row: dict[str, Any] = {} for key in new_row.keys(): new_row[key] = row[key] csv_writer.writerow(new_row) @@ -852,15 +858,15 @@ def parsed_forensic_reports_to_csv(reports: Union[OrderedDict, List[OrderedDict] def parse_report_email( - input_: Union[bytes, str], + input_: bytes | str, offline: bool = False, - ip_db_path: Optional[str] = None, - nameservers: Optional[List[str]] = None, + ip_db_path: str | None = None, + nameservers: list[str] | None = None, dns_timeout: float = 2.0, strip_attachment_payloads: bool = False, parallel: bool = False, - keep_alive: Optional[Callable] = None, -) -> OrderedDict[str, Union[str, OrderedDict]]: + keep_alive: Callable | None = None, +) -> OrderedDict[str, str | OrderedDict]: """Parse a DMARC report from an email Args: @@ -990,14 +996,14 @@ def parse_report_email( def parse_report_file( - input_: Union[str, bytes, BinaryIO], - nameservers: 
Optional[List[str]] = None, + input_: str | bytes | BinaryIO, + nameservers: list[str] | None = None, dns_timeout: float = 2.0, strip_attachment_payloads: bool = False, - ip_db_path: Optional[str] = None, + ip_db_path: str | None = None, offline: bool = False, parallel: bool = False, - keep_alive: Optional[Callable] = None, + keep_alive: Callable | None = None, ) -> OrderedDict: """Parse a DMARC aggregate or forensic file at the given path, a file-like object. or bytes @@ -1026,7 +1032,7 @@ def parse_report_file( content = file_object.read() file_object.close() - results: OrderedDict[str, Union[str, OrderedDict]] + results: OrderedDict[str, str | OrderedDict] try: report = parse_aggregate_report_file( content, @@ -1058,13 +1064,13 @@ def parse_report_file( def get_dmarc_reports_from_mbox( input_: str, - nameservers: Optional[List[str]] = None, + nameservers: list[str] | None = None, dns_timeout: float = 2.0, strip_attachment_payloads: bool = False, - ip_db_path: Optional[str] = None, + ip_db_path: str | None = None, offline: bool = False, parallel: bool = False, -) -> OrderedDict[str, List[OrderedDict]]: +) -> OrderedDict[str, list[OrderedDict]]: """Parses a mailbox in mbox format containing e-mails with attached DMARC reports Args: @@ -1079,8 +1085,8 @@ def get_dmarc_reports_from_mbox( Returns: Dictionary of Lists of ``aggregate_reports`` and ``forensic_reports`` """ - aggregate_reports: List[OrderedDict] = [] - forensic_reports: List[OrderedDict] = [] + aggregate_reports: list[OrderedDict] = [] + forensic_reports: list[OrderedDict] = [] try: mbox = mailbox.mbox(input_) message_keys = mbox.keys() @@ -1110,7 +1116,10 @@ def get_dmarc_reports_from_mbox( except mailbox.NoSuchMailboxError: raise InvalidDMARCReport(f"Mailbox {input_} does not exist") return OrderedDict( - [("aggregate_reports", aggregate_reports), ("forensic_reports", forensic_reports)] + [ + ("aggregate_reports", aggregate_reports), + ("forensic_reports", forensic_reports), + ] ) @@ -1120,15 +1129,15 @@ def get_dmarc_reports_from_mailbox( archive_folder: str = "Archive", delete: bool = False, test: bool = False, - ip_db_path: Optional[str] = None, + ip_db_path: str | None = None, offline: bool = False, - nameservers: Optional[List[str]] = None, + nameservers: list[str] | None = None, dns_timeout: float = 6.0, strip_attachment_payloads: bool = False, - results: Optional[OrderedDict[str, List[OrderedDict]]] = None, + results: OrderedDict[str, list[OrderedDict]] | None = None, batch_size: int = 10, create_folders: bool = True, -) -> OrderedDict[str, List[OrderedDict]]: +) -> OrderedDict[str, list[OrderedDict]]: """Fetches and parses DMARC reports from a mailbox Args: @@ -1259,7 +1268,10 @@ def get_dmarc_reports_from_mailbox( logger.error(f"Mailbox error: Error moving message UID {msg_uid}: {e!r}") results = OrderedDict( - [("aggregate_reports", aggregate_reports), ("forensic_reports", forensic_reports)] + [ + ("aggregate_reports", aggregate_reports), + ("forensic_reports", forensic_reports), + ] ) total_messages = len(connection.fetch_messages(reports_folder)) @@ -1291,12 +1303,12 @@ def watch_inbox( delete: bool = False, test: bool = False, check_timeout: int = 30, - ip_db_path: Optional[str] = None, + ip_db_path: str | None = None, offline: bool = False, - nameservers: Optional[List[str]] = None, + nameservers: list[str] | None = None, dns_timeout: float = 6.0, strip_attachment_payloads: bool = False, - batch_size: Optional[int] = None, + batch_size: int | None = None, ) -> None: """Watches a mailbox for new messages and sends 
the results to a callback function @@ -1337,7 +1349,7 @@ def check_callback(connection): mailbox_connection.watch(check_callback=check_callback, check_timeout=check_timeout) -def append_json(filename: str, reports: List[OrderedDict]) -> None: +def append_json(filename: str, reports: list[OrderedDict]) -> None: with open(filename, "a+", newline="\n", encoding="utf-8") as output: output_json = json.dumps(reports, ensure_ascii=False, indent=2) if output.seek(0, os.SEEK_END) != 0: @@ -1375,7 +1387,7 @@ def append_csv(filename: str, csv: str) -> None: def save_output( - results: OrderedDict[str, List[OrderedDict]], + results: OrderedDict[str, list[OrderedDict]], output_directory: str = "output", aggregate_json_filename: str = "aggregate.json", forensic_json_filename: str = "forensic.json", @@ -1441,7 +1453,7 @@ def save_output( return -def get_report_zip(results: OrderedDict[str, List[OrderedDict]]) -> bytes: +def get_report_zip(results: OrderedDict[str, list[OrderedDict]]) -> bytes: """Creates a zip file of parsed report output Args: @@ -1486,20 +1498,20 @@ def add_subdir(root_path, subdir): def email_results( - results: OrderedDict[str, List[OrderedDict]], + results: OrderedDict[str, list[OrderedDict]], host: str, mail_from: str, - mail_to: List[str], - mail_cc: Optional[List[str]] = None, - mail_bcc: Optional[List[str]] = None, + mail_to: list[str], + mail_cc: list[str] | None = None, + mail_bcc: list[str] | None = None, port: int = 0, require_encryption: bool = False, verify: bool = True, - username: Optional[str] = None, - password: Optional[str] = None, - subject: Optional[str] = None, - attachment_filename: Optional[str] = None, - message: Optional[str] = None, + username: str | None = None, + password: str | None = None, + subject: str | None = None, + attachment_filename: str | None = None, + message: str | None = None, ) -> None: """Emails parsing results as a zip file diff --git a/parsedmarc/cli.py b/parsedmarc/cli.py index eed7afc2..d52e7ce2 100644 --- a/parsedmarc/cli.py +++ b/parsedmarc/cli.py @@ -42,7 +42,8 @@ from parsedmarc.utils import is_mbox formatter = logging.Formatter( - fmt="%(levelname)8s:%(filename)s:%(lineno)d:%(message)s", datefmt="%Y-%m-%d:%H:%M:%S" + fmt="%(levelname)8s:%(filename)s:%(lineno)d:%(message)s", + datefmt="%Y-%m-%d:%H:%M:%S", ) handler = logging.StreamHandler() handler.setFormatter(formatter) @@ -204,7 +205,9 @@ def process_reports(reports_): arg_parser = ArgumentParser(description="Parses DMARC reports") arg_parser.add_argument( - "-c", "--config-file", help="a path to a configuration file " "(--silent implied)" + "-c", + "--config-file", + help="a path to a configuration file " "(--silent implied)", ) arg_parser.add_argument( "file_path", @@ -251,7 +254,10 @@ def process_reports(reports_): ) arg_parser.add_argument("-s", "--silent", action="store_true", help="only print errors") arg_parser.add_argument( - "-w", "--warnings", action="store_true", help="print warnings in addition to errors" + "-w", + "--warnings", + action="store_true", + help="print warnings in addition to errors", ) arg_parser.add_argument("--verbose", action="store_true", help="more verbose output") arg_parser.add_argument("--debug", action="store_true", help="print debugging information") @@ -805,7 +811,8 @@ def process_reports(reports_): timeout=opts.elasticsearch_timeout, ) elastic.migrate_indexes( - aggregate_indexes=[es_aggregate_index], forensic_indexes=[es_forensic_index] + aggregate_indexes=[es_aggregate_index], + forensic_indexes=[es_forensic_index], ) except 
elastic.ElasticsearchError: logger.exception("Elasticsearch Error") @@ -1025,7 +1032,10 @@ def process_reports(reports_): exit(1) results = OrderedDict( - [("aggregate_reports", aggregate_reports), ("forensic_reports", forensic_reports)] + [ + ("aggregate_reports", aggregate_reports), + ("forensic_reports", forensic_reports), + ] ) process_reports(results) diff --git a/parsedmarc/elastic.py b/parsedmarc/elastic.py index a2b2322e..a79c7b5f 100644 --- a/parsedmarc/elastic.py +++ b/parsedmarc/elastic.py @@ -1,6 +1,8 @@ +from __future__ import annotations + # Standard Library from collections import OrderedDict -from typing import Any, Dict, List, Optional, Union +from typing import Any # Installed from elasticsearch.helpers import reindex @@ -29,7 +31,7 @@ class ElasticsearchError(Exception): """Raised when an Elasticsearch error occurs""" - def __init__(self, message: Union[str, Exception]): + def __init__(self, message: str | Exception): if isinstance(message, Exception): message = repr(message) super().__init__(f"Elasticsearch Error: {message}") @@ -181,12 +183,12 @@ class AlreadySaved(ValueError): def set_hosts( - hosts: Union[str, List[str]], + hosts: str | list[str], use_ssl: bool = False, - ssl_cert_path: Optional[str] = None, - username: Optional[str] = None, - password: Optional[str] = None, - apiKey: Optional[str] = None, + ssl_cert_path: str | None = None, + username: str | None = None, + password: str | None = None, + apiKey: str | None = None, timeout: float = 60.0, ) -> None: """Set the Elasticsearch host(s) to use @@ -218,7 +220,7 @@ def set_hosts( return -def create_indexes(names: List[str], settings: Optional[Dict[str, int]] = None) -> None: +def create_indexes(names: list[str], settings: dict[str, int] | None = None) -> None: """Create Elasticsearch indexes Args: @@ -242,7 +244,8 @@ def create_indexes(names: List[str], settings: Optional[Dict[str, int]] = None) def migrate_indexes( - aggregate_indexes: Optional[List[str]] = None, forensic_indexes: Optional[List[str]] = None + aggregate_indexes: list[str] | None = None, + forensic_indexes: list[str] | None = None, ): """Update index mappings @@ -290,7 +293,7 @@ def migrate_indexes( def save_aggregate_report_to_elasticsearch( aggregate_report: OrderedDict[str, Any], - index_suffix: Optional[str] = None, + index_suffix: str | None = None, monthly_indexes: bool = False, number_of_shards: int = 1, number_of_replicas: int = 0, @@ -400,7 +403,9 @@ def save_aggregate_report_to_elasticsearch( for spf_result in record["auth_results"]["spf"]: agg_doc.add_spf_result( - domain=spf_result["domain"], scope=spf_result["scope"], result=spf_result["result"] + domain=spf_result["domain"], + scope=spf_result["scope"], + result=spf_result["result"], ) index = "dmarc_aggregate" @@ -422,7 +427,7 @@ def save_aggregate_report_to_elasticsearch( def save_forensic_report_to_elasticsearch( forensic_report: OrderedDict[str, Any], - index_suffix: Optional[str] = None, + index_suffix: str | None = None, monthly_indexes: bool = False, number_of_shards: int = 1, number_of_replicas: int = 0, diff --git a/parsedmarc/kafkaclient.py b/parsedmarc/kafkaclient.py index 843e6a37..74d92475 100644 --- a/parsedmarc/kafkaclient.py +++ b/parsedmarc/kafkaclient.py @@ -1,8 +1,9 @@ +from __future__ import annotations + # Standard Library from collections import OrderedDict import json from ssl import SSLContext, create_default_context -from typing import List, Optional, Union # Installed from kafka import KafkaProducer @@ -17,21 +18,21 @@ class KafkaError(RuntimeError): 
"""Raised when a Kafka error occurs""" - def __init__(self, message: Union[str, Exception]): + def __init__(self, message: str | Exception): if isinstance(message, Exception): message = repr(message) super().__init__(f"Kafka Error: {message}") return -class KafkaClient(object): +class KafkaClient: def __init__( self, - kafka_hosts: List[str], + kafka_hosts: list[str], ssl: bool = False, - username: Optional[str] = None, - password: Optional[str] = None, - ssl_context: Optional[SSLContext] = None, + username: str | None = None, + password: str | None = None, + ssl_context: SSLContext | None = None, ): """ Args: @@ -94,7 +95,7 @@ def generate_daterange(report): return date_range def save_aggregate_reports_to_kafka( - self, aggregate_reports: Union[OrderedDict, List[OrderedDict]], aggregate_topic: str + self, aggregate_reports: OrderedDict | list[OrderedDict], aggregate_topic: str ) -> None: """ Saves aggregate DMARC reports to Kafka @@ -135,7 +136,7 @@ def save_aggregate_reports_to_kafka( return def save_forensic_reports_to_kafka( - self, forensic_reports: Union[OrderedDict, List[OrderedDict]], forensic_topic: str + self, forensic_reports: OrderedDict | list[OrderedDict], forensic_topic: str ) -> None: """ Saves forensic DMARC reports to Kafka, sends individual diff --git a/parsedmarc/loganalytics.py b/parsedmarc/loganalytics.py index eb083252..a00e6ae0 100644 --- a/parsedmarc/loganalytics.py +++ b/parsedmarc/loganalytics.py @@ -1,5 +1,6 @@ +from __future__ import annotations + # Standard Library -from typing import Dict, List, Optional # Installed from azure.core.exceptions import HttpResponseError @@ -14,7 +15,7 @@ class LogAnalyticsException(Exception): """Errors originating from LogsIngestionClient""" -class LogAnalyticsClient(object): +class LogAnalyticsClient: """Azure Log Analytics Client Pushes the DMARC reports to Log Analytics via Data Collection Rules. @@ -30,8 +31,8 @@ def __init__( tenant_id: str, dce: str, dcr_immutable_id: str, - dcr_aggregate_stream: Optional[str] = None, - dcr_forensic_stream: Optional[str] = None, + dcr_aggregate_stream: str | None = None, + dcr_forensic_stream: str | None = None, ): """ Args: @@ -57,7 +58,7 @@ def __init__( self.logs_client = LogsIngestionClient(dce, credential=self._credential) return - def _publish_json(self, reports: List[Dict], dcr_stream: str) -> None: + def _publish_json(self, reports: list[dict], dcr_stream: str) -> None: """Publish DMARC reports to the given Data Collection Rule. Args: @@ -72,7 +73,7 @@ def _publish_json(self, reports: List[Dict], dcr_stream: str) -> None: return def publish_results( - self, results: Dict[str, List[Dict]], save_aggregate: bool, save_forensic: bool + self, results: dict[str, list[dict]], save_aggregate: bool, save_forensic: bool ) -> None: """Publish DMARC reports to Log Analytics via Data Collection Rules (DCR). 
diff --git a/parsedmarc/mail/__init__.py b/parsedmarc/mail/__init__.py index 3bbe4e8b..00acd1f4 100644 --- a/parsedmarc/mail/__init__.py +++ b/parsedmarc/mail/__init__.py @@ -4,4 +4,9 @@ from parsedmarc.mail.imap import IMAPConnection from parsedmarc.mail.mailbox_connection import MailboxConnection -__all__ = ["MailboxConnection", "MSGraphConnection", "GmailConnection", "IMAPConnection"] +__all__ = [ + "MailboxConnection", + "MSGraphConnection", + "GmailConnection", + "IMAPConnection", +] diff --git a/parsedmarc/mail/gmail.py b/parsedmarc/mail/gmail.py index 46884882..cdfef011 100644 --- a/parsedmarc/mail/gmail.py +++ b/parsedmarc/mail/gmail.py @@ -1,9 +1,11 @@ +from __future__ import annotations + # Standard Library from base64 import urlsafe_b64decode from functools import lru_cache from pathlib import Path from time import sleep -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING # Installed from google.auth.transport.requests import Request @@ -50,7 +52,7 @@ def __init__( self, token_file: str, credentials_file: str, - scopes: List[str], + scopes: list[str], include_spam_trash: bool, reports_folder: str, oauth2_port: int, @@ -75,7 +77,7 @@ def create_folder(self, folder_name: str) -> None: return logger.debug(f"Creating label {folder_name}") - request_body: "Label" = {"name": folder_name, "messageListVisibility": "show"} + request_body: Label = {"name": folder_name, "messageListVisibility": "show"} try: self.service.users().labels().create(userId="me", body=request_body).execute() except HttpError as e: @@ -85,13 +87,15 @@ def create_folder(self, folder_name: str) -> None: raise e return - def fetch_messages(self, reports_folder: str, **kwargs) -> List[str]: + def fetch_messages(self, reports_folder: str, **kwargs) -> list[str]: reports_label_id = self._find_label_id_for_label(reports_folder) results = ( self.service.users() .messages() .list( - userId="me", includeSpamTrash=self.include_spam_trash, labelIds=[reports_label_id] + userId="me", + includeSpamTrash=self.include_spam_trash, + labelIds=[reports_label_id], ) .execute() ) @@ -111,7 +115,7 @@ def delete_message(self, message_id: str) -> None: def move_message(self, message_id: str, folder_name: str): label_id = self._find_label_id_for_label(folder_name) logger.debug(f"Moving message UID {message_id} to {folder_name}") - request_body: "ModifyMessageRequest" = { + request_body: ModifyMessageRequest = { "addLabelIds": [label_id], "removeLabelIds": [self.reports_label_id], } diff --git a/parsedmarc/mail/graph.py b/parsedmarc/mail/graph.py index 33b70984..473ce8a6 100644 --- a/parsedmarc/mail/graph.py +++ b/parsedmarc/mail/graph.py @@ -1,9 +1,11 @@ +from __future__ import annotations + # Standard Library from enum import Enum from functools import lru_cache from pathlib import Path from time import sleep -from typing import Any, Dict, List, Optional, Union +from typing import Any, TypeAlias # Installed from azure.identity import ( @@ -26,11 +28,11 @@ class AuthMethod(Enum): ClientSecret = 3 -Credential = Union[DeviceCodeCredential, UsernamePasswordCredential, ClientSecretCredential] +Credential: TypeAlias = "DeviceCodeCredential | UsernamePasswordCredential | ClientSecretCredential" def _get_cache_args(token_path: Path, allow_unencrypted_storage: bool): - cache_args: Dict[str, Any] = { + cache_args: dict[str, Any] = { "cache_persistence_options": TokenCachePersistenceOptions( name="parsedmarc", allow_unencrypted_storage=allow_unencrypted_storage, @@ -42,7 +44,7 @@ def _get_cache_args(token_path: Path, 
allow_unencrypted_storage: bool): return cache_args -def _load_token(token_path: Path) -> Optional[str]: +def _load_token(token_path: Path) -> str | None: if not token_path.exists(): return None with token_path.open() as token_file: @@ -64,7 +66,8 @@ def _generate_credential(auth_method: str, token_path: Path, **kwargs) -> Creden disable_automatic_authentication=True, tenant_id=kwargs["tenant_id"], **_get_cache_args( - token_path, allow_unencrypted_storage=kwargs["allow_unencrypted_storage"] + token_path, + allow_unencrypted_storage=kwargs["allow_unencrypted_storage"], ), ) return credential @@ -77,7 +80,8 @@ def _generate_credential(auth_method: str, token_path: Path, **kwargs) -> Creden username=kwargs["username"], password=kwargs["password"], **_get_cache_args( - token_path, allow_unencrypted_storage=kwargs["allow_unencrypted_storage"] + token_path, + allow_unencrypted_storage=kwargs["allow_unencrypted_storage"], ), ) return credential @@ -130,7 +134,7 @@ def __init__( token_path=token_path, allow_unencrypted_storage=allow_unencrypted_storage, ) - client_params: Dict[str, Any] = {"credential": credential} + client_params: dict[str, Any] = {"credential": credential} if isinstance(credential, (DeviceCodeCredential, UsernamePasswordCredential)): scopes = ["Mail.ReadWrite"] # Detect if mailbox is shared @@ -157,13 +161,13 @@ def create_folder(self, folder_name: str) -> None: request_url = f"/users/{self.mailbox_name}/mailFolders{sub_url}" resp = self._client.post(request_url, json=request_body) if resp.status_code == 409: - logger.debug(f"Folder {folder_name} already exists, " f"skipping creation") + logger.debug(f"Folder {folder_name} already exists, skipping creation") elif resp.status_code == 201: logger.debug(f"Created folder {folder_name}") else: - logger.warning(f"Unknown response " f"{resp.status_code} {resp.json()}") + logger.warning(f"Unknown response {resp.status_code} {resp.json()}") - def fetch_messages(self, reports_folder: str, **kwargs) -> List[str]: + def fetch_messages(self, reports_folder: str, **kwargs) -> list[str]: """Returns a list of message UIDs in the specified folder""" folder_id = self._find_folder_id_from_folder_path(reports_folder) url = f"/users/{self.mailbox_name}/mailFolders/" f"{folder_id}/messages" @@ -175,7 +179,7 @@ def fetch_messages(self, reports_folder: str, **kwargs) -> List[str]: def _get_all_messages(self, url: str, batch_size: int): messages: list - params: Dict[str, Any] = {"$select": "id"} + params: dict[str, Any] = {"$select": "id"} if batch_size and batch_size > 0: params["$top"] = batch_size else: @@ -199,17 +203,13 @@ def mark_message_read(self, message_id: str): url = f"/users/{self.mailbox_name}/messages/{message_id}" resp = self._client.patch(url, json={"isRead": "true"}) if resp.status_code != 200: - raise RuntimeWarning( - f"Failed to mark message read" f"{resp.status_code}: {resp.json()}" - ) + raise RuntimeWarning(f"Failed to mark message read {resp.status_code}: {resp.json()}") def fetch_message(self, message_id: str) -> str: url = f"/users/{self.mailbox_name}/messages/{message_id}/$value" result = self._client.get(url) if result.status_code != 200: - raise RuntimeWarning( - f"Failed to fetch message" f"{result.status_code}: {result.json()}" - ) + raise RuntimeWarning(f"Failed to fetch message {result.status_code}: {result.json()}") self.mark_message_read(message_id) return result.text @@ -217,7 +217,7 @@ def delete_message(self, message_id: str): url = f"/users/{self.mailbox_name}/messages/{message_id}" resp = 
self._client.delete(url) if resp.status_code != 204: - raise RuntimeWarning(f"Failed to delete message " f"{resp.status_code}: {resp.json()}") + raise RuntimeWarning(f"Failed to delete message {resp.status_code}: {resp.json()}") def move_message(self, message_id: str, folder_name: str): folder_id = self._find_folder_id_from_folder_path(folder_name) @@ -225,7 +225,7 @@ def move_message(self, message_id: str, folder_name: str): url = f"/users/{self.mailbox_name}/messages/{message_id}/move" resp = self._client.post(url, json=request_body) if resp.status_code != 201: - raise RuntimeWarning(f"Failed to move message " f"{resp.status_code}: {resp.json()}") + raise RuntimeWarning(f"Failed to move message {resp.status_code}: {resp.json()}") def keepalive(self): # Not needed @@ -249,7 +249,7 @@ def _find_folder_id_from_folder_path(self, folder_name: str) -> str: else: return self._find_folder_id_with_parent(folder_name, None) - def _find_folder_id_with_parent(self, folder_name: str, parent_folder_id: Optional[str]): + def _find_folder_id_with_parent(self, folder_name: str, parent_folder_id: str | None): sub_url = "" if parent_folder_id is not None: sub_url = f"/{parent_folder_id}/childFolders" @@ -257,7 +257,7 @@ def _find_folder_id_with_parent(self, folder_name: str, parent_folder_id: Option filter = f"?$filter=displayName eq '{folder_name}'" folders_resp = self._client.get(url + filter) if folders_resp.status_code != 200: - raise RuntimeWarning(f"Failed to list folders." f"{folders_resp.json()}") + raise RuntimeWarning(f"Failed to list folders. {folders_resp.json()}") folders: list = folders_resp.json()["value"] matched_folders = [folder for folder in folders if folder["displayName"] == folder_name] if len(matched_folders) == 0: diff --git a/parsedmarc/mail/imap.py b/parsedmarc/mail/imap.py index cb1c5016..e2fdd77d 100644 --- a/parsedmarc/mail/imap.py +++ b/parsedmarc/mail/imap.py @@ -1,7 +1,8 @@ +from __future__ import annotations + # Standard Library from socket import timeout from time import sleep -from typing import Optional # Installed from imapclient.exceptions import IMAPClientError @@ -17,10 +18,10 @@ class IMAPConnection(MailboxConnection): def __init__( self, - host: Optional[str] = None, - user: Optional[str] = None, - password: Optional[str] = None, - port: Optional[int] = None, + host: str | None = None, + user: str | None = None, + password: str | None = None, + port: int | None = None, ssl: bool = True, verify: bool = True, timeout: int = 30, diff --git a/parsedmarc/mail/mailbox_connection.py b/parsedmarc/mail/mailbox_connection.py index a36cbbdc..68015f08 100644 --- a/parsedmarc/mail/mailbox_connection.py +++ b/parsedmarc/mail/mailbox_connection.py @@ -1,6 +1,7 @@ +from __future__ import annotations + # Standard Library from abc import ABC -from typing import List, Union class MailboxConnection(ABC): @@ -9,10 +10,10 @@ class MailboxConnection(ABC): def create_folder(self, folder_name: str) -> None: raise NotImplementedError - def fetch_messages(self, reports_folder: str, **kwargs) -> List[str]: + def fetch_messages(self, reports_folder: str, **kwargs) -> list[str]: raise NotImplementedError - def fetch_message(self, message_id: str) -> Union[str, bytes]: + def fetch_message(self, message_id: str) -> str | bytes: raise NotImplementedError def delete_message(self, message_id: str) -> None: diff --git a/parsedmarc/s3.py b/parsedmarc/s3.py index 1c7181da..877634b5 100644 --- a/parsedmarc/s3.py +++ b/parsedmarc/s3.py @@ -1,6 +1,8 @@ +from __future__ import annotations + # Standard 
Library import json -from typing import Any, Dict, Optional +from typing import Any # Installed import boto3 @@ -10,17 +12,17 @@ from parsedmarc.utils import human_timestamp_to_datetime -class S3Client(object): +class S3Client: """A client for a Amazon S3""" def __init__( self, bucket_name: str, bucket_path: str, - region_name: Optional[str] = None, - endpoint_url: Optional[str] = None, - access_key_id: Optional[str] = None, - secret_access_key: Optional[str] = None, + region_name: str | None = None, + endpoint_url: str | None = None, + access_key_id: str | None = None, + secret_access_key: str | None = None, ): """ Args: @@ -52,15 +54,15 @@ def __init__( self.bucket = self.s3.Bucket(self.bucket_name) return - def save_aggregate_report_to_s3(self, report: Dict[str, Any]) -> None: + def save_aggregate_report_to_s3(self, report: dict[str, Any]) -> None: self.save_report_to_s3(report, "aggregate") return - def save_forensic_report_to_s3(self, report: Dict[str, Any]) -> None: + def save_forensic_report_to_s3(self, report: dict[str, Any]) -> None: self.save_report_to_s3(report, "forensic") return - def save_report_to_s3(self, report: Dict[str, Any], report_type: str): + def save_report_to_s3(self, report: dict[str, Any], report_type: str): report_date = human_timestamp_to_datetime(report["report_metadata"]["begin_date"]) report_id = report["report_metadata"]["report_id"] path_template = "{0}/{1}/year={2}/month={3:02d}/day={4:02d}/{5}.json" diff --git a/parsedmarc/splunk.py b/parsedmarc/splunk.py index 59066ec7..be5f7ec1 100644 --- a/parsedmarc/splunk.py +++ b/parsedmarc/splunk.py @@ -1,7 +1,9 @@ +from __future__ import annotations + # Standard Library import json import socket -from typing import Any, Dict, List, Union +from typing import Any from urllib.parse import urlparse # Installed @@ -19,14 +21,14 @@ class SplunkError(RuntimeError): """Raised when a Splunk API error occurs""" - def __init__(self, message: Union[str, Exception]): + def __init__(self, message: str | Exception): if isinstance(message, Exception): message = repr(message) super().__init__(f"Splunk Error: {message}") return -class HECClient(object): +class HECClient: """A client for a Splunk HTTP Events Collector (HEC)""" # http://docs.splunk.com/Documentation/Splunk/latest/Data/AboutHEC @@ -59,7 +61,7 @@ def __init__( self.session = requests.Session() self.timeout = timeout self.session.verify = verify - self._common_data: Dict[str, Any] = dict( + self._common_data: dict[str, Any] = dict( host=self.host, source=self.source, index=self.index ) @@ -69,9 +71,7 @@ def __init__( } return - def save_aggregate_reports_to_splunk( - self, aggregate_reports: Union[Dict, List[Dict[str, Any]]] - ): + def save_aggregate_reports_to_splunk(self, aggregate_reports: dict | list[dict[str, Any]]): """Save aggregate DMARC reports to Splunk Args: @@ -124,7 +124,7 @@ def save_aggregate_reports_to_splunk( raise SplunkError(response["text"]) return - def save_forensic_reports_to_splunk(self, forensic_reports: Union[Dict, List[Dict[str, Any]]]): + def save_forensic_reports_to_splunk(self, forensic_reports: dict | list[dict[str, Any]]): """Save forensic DMARC reports to Splunk Args: diff --git a/parsedmarc/syslog.py b/parsedmarc/syslog.py index 41021ff8..eefd27ab 100644 --- a/parsedmarc/syslog.py +++ b/parsedmarc/syslog.py @@ -10,7 +10,7 @@ ) -class SyslogClient(object): +class SyslogClient: """A client for Syslog""" def __init__(self, server_name: str, server_port: int): diff --git a/parsedmarc/utils.py b/parsedmarc/utils.py index f3ddd872..94d4d34e 
100644 --- a/parsedmarc/utils.py +++ b/parsedmarc/utils.py @@ -1,5 +1,7 @@ """Utility functions that might be useful for other projects""" +from __future__ import annotations + # Standard Library import atexit import base64 @@ -15,7 +17,7 @@ import shutil import subprocess import tempfile -from typing import Any, Dict, List, Optional, Union +from typing import Any # Installed from dateutil.parser import parse as parse_date @@ -95,10 +97,10 @@ def get_base_domain(domain: str) -> str: def query_dns( domain: str, record_type: str, - cache: Optional[ExpiringDict] = None, - nameservers: Optional[List[str]] = None, + cache: ExpiringDict | None = None, + nameservers: list[str] | None = None, timeout: float = 2.0, -) -> List[str]: +) -> list[str]: """Make a DNS query Args: @@ -134,7 +136,10 @@ def query_dns( resolver.lifetime = timeout if record_type == "TXT": resource_records = list( - map(lambda r: r.strings, resolver.resolve(domain, record_type, lifetime=timeout)) + map( + lambda r: r.strings, + resolver.resolve(domain, record_type, lifetime=timeout), + ) ) _resource_record = [ resource_record[0][:0].join(resource_record) @@ -157,10 +162,10 @@ def query_dns( def get_reverse_dns( ip_address: str, - cache: Optional[ExpiringDict] = None, - nameservers: Optional[List[str]] = None, + cache: ExpiringDict | None = None, + nameservers: list[str] | None = None, timeout: float = 2.0, -) -> Optional[str]: +) -> str | None: """Resolve an IP address to a hostname using a reverse DNS query Args: @@ -172,7 +177,7 @@ def get_reverse_dns( Returns: The reverse DNS hostname (if any) """ - hostname: Optional[str] = None + hostname: str | None = None try: address = str(dns.reversename.from_address(ip_address)) hostname = query_dns(address, "PTR", cache=cache, nameservers=nameservers, timeout=timeout)[ @@ -240,7 +245,7 @@ def human_timestamp_to_timestamp(human_timestamp: str) -> float: return human_timestamp_to_datetime(human_timestamp).timestamp() -def get_ip_address_country(ip_address: str, db_path: Optional[str] = None) -> Optional[str]: +def get_ip_address_country(ip_address: str, db_path: str | None = None) -> str | None: """Get the ISO code for the country associated with the given IPv4 or IPv6 address Args: @@ -299,10 +304,10 @@ def get_ip_address_country(ip_address: str, db_path: Optional[str] = None) -> Op def get_ip_address_info( ip_address: str, - ip_db_path: Optional[str] = None, - cache: Optional[ExpiringDict] = None, + ip_db_path: str | None = None, + cache: ExpiringDict | None = None, offline: bool = False, - nameservers: Optional[List[str]] = None, + nameservers: list[str] | None = None, timeout: float = 2.0, parallel: bool = False, ) -> OrderedDict: @@ -358,7 +363,12 @@ def parse_email_address(original_address: str) -> OrderedDict: domain = address_parts[-1].lower() return OrderedDict( - [("display_name", display_name), ("address", address), ("local", local), ("domain", domain)] + [ + ("display_name", display_name), + ("address", address), + ("local", local), + ("domain", domain), + ] ) @@ -447,7 +457,7 @@ def convert_outlook_msg(msg_bytes: bytes) -> bytes: return rfc822 -def parse_email(data: Union[bytes, str], strip_attachment_payloads: bool = False) -> Dict[str, Any]: +def parse_email(data: bytes | str, strip_attachment_payloads: bool = False) -> dict[str, Any]: """A simplified email parser Args: diff --git a/pyproject.toml b/pyproject.toml index 37f22a0c..4f678157 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ "azure-identity>=1.8.0", 
"azure-monitor-ingestion>=1.0.0", "boto3>=1.16.63", + "bs4", "dateparser>=1.1.1", "dnspython>=2.0.0", "elasticsearch-dsl==7.4.0", @@ -75,20 +76,29 @@ dev = [ "types-requests", "types-tqdm", "types-xmltodict", - ### dev.sh dependencies ## Formatting / Linting "validate-pyproject[all]", + "pyupgrade", "black", "isort", #"pylint", "mypy", - ## Testing "flake8", - #"pytest", - ## REPL - "bpython", + ## Testing + "codecov", +] + +docs = [ + "alabaster>=0.7.12", + "Babel>=2.7.0", + "docutils>=0.14,<0.18", + "myst-parser[linkify]>=0.18.0", + "pygments>=2.11.1", + "sphinx>=1.0.5", + "sphinx_rtd_theme>=0.4.3", ] + [project.scripts] parsedmarc = "parsedmarc.cli:_main" @@ -114,3 +124,7 @@ import_heading_stdlib = "Standard Library" import_heading_thirdparty = "Installed" import_heading_firstparty = "Package" import_heading_localfolder = "Local" + +[tool.black] +line-length = 100 +target-version = ["py37", "py38", "py39", "py310", "py311", "py312"] diff --git a/requirements.txt b/requirements.txt index 30dfa3f1..362641f1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,42 +1,5 @@ -tqdm>=4.31.1 -pygments>=2.11.1 -dnspython>=2.0.0 -expiringdict>=1.1.4 -urllib3>=1.25.7 -requests>=2.22.0 -publicsuffixlist>=0.10.0 -xmltodict>=0.12.0 -geoip2>=3.0.0 -imapclient>=2.1.0 -dateparser>=1.1.1 -elasticsearch<7.14.0 -elasticsearch-dsl>=7.4.0 -kafka-python>=1.4.4 -mailsuite>=1.6.1 nose>=1.3.7 wheel>=0.37.0 -flake8>=3.7.8 jinja2>=2.10.1 packaging>=19.1 imagesize>=1.1.0 -alabaster>=0.7.12 -Babel>=2.7.0 -docutils<0.18,>=0.14 -sphinx>=1.0.5 -sphinx_rtd_theme>=0.4.3 -codecov>=2.0.15 -lxml>=4.4.0 -boto3>=1.16.63 -msgraph-core>=0.2.2 -azure-identity>=1.8.0 -azure-monitor-ingestion>=1.0.0 -google-api-core>=2.4.0 -google-api-python-client>=2.35.0 -google-auth>=2.3.3 -google-auth-httplib2>=0.1.0 -google-auth-oauthlib>=0.4.6 -hatch>=1.5.0 -myst-parser>=0.18.0 -myst-parser[linkify] -requests -bs4 From 88a089bcdc801d7c66e1d2aee37a33d750accab0 Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Thu, 1 Feb 2024 02:51:50 +1100 Subject: [PATCH 07/15] Use typing extentions for TypeAlias --- parsedmarc/mail/graph.py | 8 +++++++- pyproject.toml | 1 + 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/parsedmarc/mail/graph.py b/parsedmarc/mail/graph.py index 473ce8a6..6bb1e66e 100644 --- a/parsedmarc/mail/graph.py +++ b/parsedmarc/mail/graph.py @@ -5,7 +5,13 @@ from functools import lru_cache from pathlib import Path from time import sleep -from typing import Any, TypeAlias +import sys +from typing import Any + +if sys.version_info < (3, 10): + from typing_extensions import TypeAlias +else: + from typing import TypeAlias # Installed from azure.identity import ( diff --git a/pyproject.toml b/pyproject.toml index 4f678157..79b4dafe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,7 @@ dependencies = [ "publicsuffixlist>=0.10.0", "requests>=2.22.0", "tqdm>=4.31.1", + "typing-extensions;python_version<'3.10'", "urllib3>=1.25.7", "xmltodict>=0.12.0", ] From 64c50049d56804e7143921bfd9c30b7870ff6ed4 Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Thu, 1 Feb 2024 03:01:31 +1100 Subject: [PATCH 08/15] Update dependencies --- mypy.ini | 10 ---------- pyproject.toml | 2 +- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/mypy.ini b/mypy.ini index 03056b65..879b3ffc 100644 --- a/mypy.ini +++ b/mypy.ini @@ -38,14 +38,4 @@ ignore_missing_imports = True # https://github.com/cffnpwr/google-auth-oauthlib-stubs/issues/1 ignore_missing_imports = True -# pip install the following: -# lxml-stubs -# 
types-python-dateutil -# types-requests -# types-tqdm -# types-xmltodict -# google-api-python-client-stubs -# google-auth-stubs -# types-boto3 - # importlib_resources https://github.com/python/importlib_resources/blob/main/importlib_resources/py.typed:23 diff --git a/pyproject.toml b/pyproject.toml index 79b4dafe..fde95e1d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ dependencies = [ "kafka-python>=1.4.4", "lxml>=4.4.0", "mailsuite>=1.6.1", - "msgraph-core>=0.2.2", + "msgraph-core>=0.2.2,<1.0.0", # msgraph-core 1.0 has breaking changes "publicsuffixlist>=0.10.0", "requests>=2.22.0", "tqdm>=4.31.1", From 0a78bed66ebc6dea48c4ba68bf5a3404963ee0f7 Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Sun, 4 Feb 2024 15:00:37 +1100 Subject: [PATCH 09/15] Set static password for CI --- .github/workflows/python-tests.yml | 4 ++++ ci.ini | 3 +-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 44b3e8aa..00f9e2f7 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -30,6 +30,10 @@ jobs: sudo apt-get update && sudo apt-get install elasticsearch sudo systemctl restart elasticsearch sudo systemctl --no-pager status elasticsearch + # Note: We set a static password as ES8 requires that a password is set. + # We can't use the randomly generateds one in our CI script so we set one here + # For real world applications you should NOT commit passwords to git like this. + echo "y\nWFXvAZ6xvcAhx\nWFXvAZ6xvcAhx" | sudo /usr/share/elasticsearch/bin/elasticsearch-reset-password --interactive -u elastic - name: Install Python dependencies run: | python -m pip install --upgrade pip diff --git a/ci.ini b/ci.ini index 3e35e9cf..04a57e24 100644 --- a/ci.ini +++ b/ci.ini @@ -4,8 +4,7 @@ save_forensic = True debug = True [elasticsearch] -hosts = http://localhost:9200 -ssl = False +hosts = https://elastic:WFXvAZ6xvcAhx@localhost:9200 number_of_shards=2 number_of_replicas=2 From 094f35794c21f944099f3cb6024cf3c70d389493 Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Sun, 4 Feb 2024 15:11:18 +1100 Subject: [PATCH 10/15] Fix CI password --- .github/workflows/python-tests.yml | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 00f9e2f7..204090e0 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -33,7 +33,13 @@ jobs: # Note: We set a static password as ES8 requires that a password is set. # We can't use the randomly generateds one in our CI script so we set one here # For real world applications you should NOT commit passwords to git like this. 
- echo "y\nWFXvAZ6xvcAhx\nWFXvAZ6xvcAhx" | sudo /usr/share/elasticsearch/bin/elasticsearch-reset-password --interactive -u elastic + # Note Syntax: https://github.com/orgs/community/discussions/25469#discussioncomment-3248006 + sudo /usr/share/elasticsearch/bin/elasticsearch-reset-password --interactive -u elastic <<'EOF' + y + WFXvAZ6xvcAhx + WFXvAZ6xvcAhx + EOF + - name: Install Python dependencies run: | python -m pip install --upgrade pip From c2089185f645cf17d4f579d8c698586ed773296e Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Sun, 4 Feb 2024 15:16:21 +1100 Subject: [PATCH 11/15] remove redundant black config --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index fde95e1d..3dd6699c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -128,4 +128,3 @@ import_heading_localfolder = "Local" [tool.black] line-length = 100 -target-version = ["py37", "py38", "py39", "py310", "py311", "py312"] From 451832a6d865162090daaae3d30ab11f25b91c70 Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Sun, 4 Feb 2024 15:46:03 +1100 Subject: [PATCH 12/15] Skip directories when using CLI --- parsedmarc/cli.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/parsedmarc/cli.py b/parsedmarc/cli.py index d52e7ce2..8d767952 100644 --- a/parsedmarc/cli.py +++ b/parsedmarc/cli.py @@ -11,6 +11,7 @@ import logging from multiprocessing import Pool, Value import os +import os.path from ssl import CERT_NONE, create_default_context import sys import time @@ -879,8 +880,8 @@ def process_reports(reports_): if is_mbox(file_path): mbox_paths.append(file_path) - file_paths = list(set(file_paths)) - mbox_paths = list(set(mbox_paths)) + file_paths = [p for p in set(file_paths) if os.path.isfile(p)] + mbox_paths = [p for p in set(mbox_paths) if os.path.isfile(p)] for mbox_path in mbox_paths: file_paths.remove(mbox_path) From b5d39bf26b0a5ab76e865c5770acfa9939c84f3e Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Sun, 4 Feb 2024 15:50:32 +1100 Subject: [PATCH 13/15] Add hatchling dependency, run black in CI --- .github/workflows/python-tests.yml | 1 + pyproject.toml | 2 ++ 2 files changed, 3 insertions(+) diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 204090e0..18876c5f 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -51,6 +51,7 @@ jobs: - name: Check code style run: | flake8 *.py parsedmarc/*.py + black parsedmarc --check --diff - name: Run unit tests run: | coverage run tests.py diff --git a/pyproject.toml b/pyproject.toml index 3dd6699c..2a70f531 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,8 @@ dev = [ "flake8", ## Testing "codecov", + ## Building + "hatchling", ] docs = [ From e3059bb09f89e23bc8c4efdda9400908e996485e Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Sun, 4 Feb 2024 15:55:08 +1100 Subject: [PATCH 14/15] fix hatch --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 2a70f531..b9e16ddd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,7 +88,7 @@ dev = [ ## Testing "codecov", ## Building - "hatchling", + "hatch", ] docs = [ From 6841201285c0ed5ba080ce7807bb121e45efe9a9 Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Wed, 6 Mar 2024 13:55:35 +1100 Subject: [PATCH 15/15] Modify ci comments --- .github/workflows/python-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-tests.yml 
b/.github/workflows/python-tests.yml index 18876c5f..026cb0f2 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -31,7 +31,7 @@ jobs: sudo systemctl restart elasticsearch sudo systemctl --no-pager status elasticsearch # Note: We set a static password as ES8 requires that a password is set. - # We can't use the randomly generateds one in our CI script so we set one here + # We can't use a randomly generated one in our CI script so we set one here # For real world applications you should NOT commit passwords to git like this. # Note Syntax: https://github.com/orgs/community/discussions/25469#discussioncomment-3248006 sudo /usr/share/elasticsearch/bin/elasticsearch-reset-password --interactive -u elastic <<'EOF'
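One of the later patches above ("Skip directories when using CLI") filters `file_paths` and `mbox_paths` through `os.path.isfile()`, so CLI arguments that expand to directories are dropped instead of being handed to the report parser. The sketch below illustrates the same idea in isolation; `collect_report_paths` is an invented helper rather than parsedmarc's actual CLI code, and the sample glob pattern is hypothetical.

# Standalone sketch, not parsedmarc code: why directories are skipped.
# A shell glob such as `samples/*` can match directories as well as files;
# opening a directory with open() fails (IsADirectoryError on POSIX), so
# only regular files are kept.
from __future__ import annotations

import os.path
from glob import glob


def collect_report_paths(patterns: list[str]) -> list[str]:
    """Expand globs, de-duplicate, and keep only regular files."""
    candidates: set[str] = set()
    for pattern in patterns:
        candidates.update(glob(pattern))
    return sorted(path for path in candidates if os.path.isfile(path))


if __name__ == "__main__":
    # Hypothetical invocation: any directories matched by the glob are dropped.
    print(collect_report_paths(["samples/*"]))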