From fb8b5c70dd3c342ee5f31c0454c56af9f0c333ad Mon Sep 17 00:00:00 2001 From: Sean Morgan Date: Fri, 29 Mar 2024 10:22:11 -0700 Subject: [PATCH] Fix newline terminator parsing --- modelscan/tools/picklescanner.py | 28 ++++++++++++--- tests/test_modelscan.py | 59 ++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 5 deletions(-) diff --git a/modelscan/tools/picklescanner.py b/modelscan/tools/picklescanner.py index 80d7409..0320529 100644 --- a/modelscan/tools/picklescanner.py +++ b/modelscan/tools/picklescanner.py @@ -1,7 +1,7 @@ import logging import pickletools # nosec from tarfile import TarError -from typing import IO, Any, Dict, List, Set, Tuple, Union +from typing import IO, Any, Dict, List, Set, Tuple, Union, Optional import numpy as np @@ -17,15 +17,15 @@ class GenOpsError(Exception): - def __init__(self, msg: str): + def __init__(self, msg: str, globals: Optional[Set[Tuple[str, str]]]): self.msg = msg + self.globals = globals super().__init__() def __str__(self) -> str: return self.msg -# # TODO: handle methods loading other Pickle files (either mark as suspicious, or follow calls to scan other files [preventing infinite loops]) # # pickle.loads() @@ -62,7 +62,11 @@ def _list_globals( pickletools.genops(data) ) except Exception as e: - raise GenOpsError(str(e)) + # Given we can have multiple pickles in a file, we may have already successfully extracted globals from a valid pickle. + # Thus return the already found globals in the error & let the caller decide what to do. + globals_opt = globals if len(globals) > 0 else None + raise GenOpsError(str(e), globals_opt) + last_byte = data.read(1) data.seek(-1, 1) @@ -126,6 +130,12 @@ def scan_pickle_bytes( try: raw_globals = _list_globals(model.get_stream(), multiple_pickles) except GenOpsError as e: + if e.globals is not None: + return _build_scan_result_from_raw_globals( + e.globals, + model, + settings, + ) return ScanResults( issues, [ @@ -138,8 +148,16 @@ def scan_pickle_bytes( ], [], ) + logger.debug("Global imports in %s: %s", model, raw_globals, settings) + return _build_scan_result_from_raw_globals(raw_globals, model, settings) + - logger.debug("Global imports in %s: %s", model.get_source(), raw_globals) +def _build_scan_result_from_raw_globals( + raw_globals: Set[Tuple[str, str]], + model: Model, + settings: Dict[str, Any], +) -> ScanResults: + issues: List[Issue] = [] severities = { "CRITICAL": IssueSeverity.CRITICAL, "HIGH": IssueSeverity.HIGH, diff --git a/tests/test_modelscan.py b/tests/test_modelscan.py index 890a705..72b69bf 100644 --- a/tests/test_modelscan.py +++ b/tests/test_modelscan.py @@ -134,6 +134,36 @@ def malicious13_gen() -> bytes: return p +def malicious14_gen() -> bytes: + p = b"".join( + [ + pickle.UNICODE + b"os\n", + pickle.PUT + b"2\n", + pickle.POP, + pickle.UNICODE + b"system\n", + pickle.PUT + b"3\n", + pickle.POP, + pickle.UNICODE + b"torch\n", + pickle.PUT + b"0\n", + pickle.POP, + pickle.UNICODE + b"LongStorage\n", + pickle.PUT + b"1\n", + pickle.POP, + pickle.GET + b"2\n", + pickle.GET + b"3\n", + pickle.STACK_GLOBAL, + pickle.MARK, + pickle.UNICODE + b"cat flag.txt\n", + pickle.TUPLE, + pickle.REDUCE, + pickle.STOP, + b"\n\n\t\t", + ] + ) + + return p + + def initialize_pickle_file(path: str, obj: Any, version: int) -> None: if not os.path.exists(path): with open(path, "wb") as file: @@ -288,6 +318,8 @@ def file_path(tmp_path_factory: Any) -> Any: initialize_data_file(f"{tmp}/data/malicious13.pkl", malicious13_gen()) + initialize_data_file(f"{tmp}/data/malicious14.pkl", malicious14_gen()) + return tmp @@ -950,6 +982,22 @@ def test_scan_pickle_operators(file_path: Any) -> None: malicious13.scan(Path(f"{file_path}/data/malicious13.pkl")) assert malicious13.issues.all_issues == expected_malicious13 + expected_malicious14 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "os", + "system", + IssueSeverity.CRITICAL, + f"{file_path}/data/malicious14.pkl", + ), + ) + ] + malicious14 = ModelScan() + malicious14.scan(Path(f"{file_path}/data/malicious14.pkl")) + assert malicious14.issues.all_issues == expected_malicious14 + def test_scan_directory_path(file_path: str) -> None: expected = { @@ -1204,6 +1252,16 @@ def test_scan_directory_path(file_path: str) -> None: f"{file_path}/data/malicious13.pkl", ), ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "os", + "system", + IssueSeverity.CRITICAL, + f"{file_path}/data/malicious14.pkl", + ), + ), } ms = ModelScan() p = Path(f"{file_path}/data/") @@ -1221,6 +1279,7 @@ def test_scan_directory_path(file_path: str) -> None: f"malicious11.pkl", f"malicious12.pkl", f"malicious13.pkl", + f"malicious14.pkl", f"malicious1_v0.dill", f"malicious1_v3.dill", f"malicious1_v4.dill",