Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix newline terminator parsing #124

Merged
merged 1 commit into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions modelscan/tools/picklescanner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import pickletools # nosec
from tarfile import TarError
from typing import IO, Any, Dict, List, Set, Tuple, Union
from typing import IO, Any, Dict, List, Set, Tuple, Union, Optional

import numpy as np

Expand All @@ -17,15 +17,15 @@


class GenOpsError(Exception):
def __init__(self, msg: str):
def __init__(self, msg: str, globals: Optional[Set[Tuple[str, str]]]):
self.msg = msg
self.globals = globals
super().__init__()

def __str__(self) -> str:
return self.msg


#
# TODO: handle methods loading other Pickle files (either mark as suspicious, or follow calls to scan other files [preventing infinite loops])
#
# pickle.loads()
Expand Down Expand Up @@ -62,7 +62,11 @@ def _list_globals(
pickletools.genops(data)
)
except Exception as e:
raise GenOpsError(str(e))
# Given we can have multiple pickles in a file, we may have already successfully extracted globals from a valid pickle.
# Thus return the already found globals in the error & let the caller decide what to do.
globals_opt = globals if len(globals) > 0 else None
raise GenOpsError(str(e), globals_opt)

last_byte = data.read(1)
data.seek(-1, 1)

Expand Down Expand Up @@ -126,6 +130,12 @@ def scan_pickle_bytes(
try:
raw_globals = _list_globals(model.get_stream(), multiple_pickles)
except GenOpsError as e:
if e.globals is not None:
return _build_scan_result_from_raw_globals(
e.globals,
model,
settings,
)
return ScanResults(
issues,
[
Expand All @@ -138,8 +148,16 @@ def scan_pickle_bytes(
],
[],
)
logger.debug("Global imports in %s: %s", model, raw_globals, settings)
return _build_scan_result_from_raw_globals(raw_globals, model, settings)


logger.debug("Global imports in %s: %s", model.get_source(), raw_globals)
def _build_scan_result_from_raw_globals(
raw_globals: Set[Tuple[str, str]],
model: Model,
settings: Dict[str, Any],
) -> ScanResults:
issues: List[Issue] = []
severities = {
"CRITICAL": IssueSeverity.CRITICAL,
"HIGH": IssueSeverity.HIGH,
Expand Down
59 changes: 59 additions & 0 deletions tests/test_modelscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,36 @@ def malicious13_gen() -> bytes:
return p


def malicious14_gen() -> bytes:
p = b"".join(
[
pickle.UNICODE + b"os\n",
pickle.PUT + b"2\n",
pickle.POP,
pickle.UNICODE + b"system\n",
pickle.PUT + b"3\n",
pickle.POP,
pickle.UNICODE + b"torch\n",
pickle.PUT + b"0\n",
pickle.POP,
pickle.UNICODE + b"LongStorage\n",
pickle.PUT + b"1\n",
pickle.POP,
pickle.GET + b"2\n",
pickle.GET + b"3\n",
pickle.STACK_GLOBAL,
pickle.MARK,
pickle.UNICODE + b"cat flag.txt\n",
pickle.TUPLE,
pickle.REDUCE,
pickle.STOP,
b"\n\n\t\t",
]
)

return p


def initialize_pickle_file(path: str, obj: Any, version: int) -> None:
if not os.path.exists(path):
with open(path, "wb") as file:
Expand Down Expand Up @@ -288,6 +318,8 @@ def file_path(tmp_path_factory: Any) -> Any:

initialize_data_file(f"{tmp}/data/malicious13.pkl", malicious13_gen())

initialize_data_file(f"{tmp}/data/malicious14.pkl", malicious14_gen())

return tmp


Expand Down Expand Up @@ -950,6 +982,22 @@ def test_scan_pickle_operators(file_path: Any) -> None:
malicious13.scan(Path(f"{file_path}/data/malicious13.pkl"))
assert malicious13.issues.all_issues == expected_malicious13

expected_malicious14 = [
Issue(
IssueCode.UNSAFE_OPERATOR,
IssueSeverity.CRITICAL,
OperatorIssueDetails(
"os",
"system",
IssueSeverity.CRITICAL,
f"{file_path}/data/malicious14.pkl",
),
)
]
malicious14 = ModelScan()
malicious14.scan(Path(f"{file_path}/data/malicious14.pkl"))
assert malicious14.issues.all_issues == expected_malicious14


def test_scan_directory_path(file_path: str) -> None:
expected = {
Expand Down Expand Up @@ -1204,6 +1252,16 @@ def test_scan_directory_path(file_path: str) -> None:
f"{file_path}/data/malicious13.pkl",
),
),
Issue(
IssueCode.UNSAFE_OPERATOR,
IssueSeverity.CRITICAL,
OperatorIssueDetails(
"os",
"system",
IssueSeverity.CRITICAL,
f"{file_path}/data/malicious14.pkl",
),
),
}
ms = ModelScan()
p = Path(f"{file_path}/data/")
Expand All @@ -1221,6 +1279,7 @@ def test_scan_directory_path(file_path: str) -> None:
f"malicious11.pkl",
f"malicious12.pkl",
f"malicious13.pkl",
f"malicious14.pkl",
f"malicious1_v0.dill",
f"malicious1_v3.dill",
f"malicious1_v4.dill",
Expand Down
Loading