Skip to content

Commit

Permalink
Merge branch 'main' into fix/flatten-command-type
Browse files Browse the repository at this point in the history
  • Loading branch information
JSCU-CNI authored Sep 25, 2024
2 parents 4b418c6 + 1701dcf commit d1ba5ce
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 19 deletions.
115 changes: 97 additions & 18 deletions flow/record/adapter/xlsx.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import openpyxl
from base64 import b64decode, b64encode
from datetime import datetime, timezone
from typing import Any, Iterator

from openpyxl import Workbook, load_workbook
from openpyxl.cell.cell import ILLEGAL_CHARACTERS_RE

from flow import record
from flow.record import fieldtypes
from flow.record.adapter import AbstractReader, AbstractWriter
from flow.record.fieldtypes.net import ipaddress
from flow.record.selector import make_selector
from flow.record.utils import is_stdout

Expand All @@ -14,23 +21,72 @@
"""


def sanitize_fieldvalues(values: Iterator[Any]) -> Iterator[Any]:
    """Sanitize field values so openpyxl will accept them.

    Values are yielded in input order, converted as follows:

    * timezone-aware ``datetime``: converted to UTC, then made naive;
    * ``bytes``: ``b"<text>"`` when the bytes decode cleanly and the result
      contains no characters matched by ``ILLEGAL_CHARACTERS_RE``, otherwise
      ``base64:<payload>``;
    * ip addresses, lists and posix/windows paths (including subclasses):
      converted with ``str()``;
    * anything else: passed through untouched.

    Args:
        values: The raw field values of a single record.

    Yields:
        One sanitized value per input value.
    """
    for value in values:
        # openpyxl doesn't support timezone-aware datetime instances,
        # so we convert to UTC and then remove the timezone info.
        if isinstance(value, datetime) and value.tzinfo is not None:
            value = value.astimezone(timezone.utc).replace(tzinfo=None)

        elif isinstance(value, bytes):
            try:
                printable = 'b"' + value.decode() + '"'
            except UnicodeDecodeError:
                printable = None

            # Fall back to base64 when the bytes are not valid text, or when
            # the text holds characters the illegal-character pattern matches.
            if printable is None or ILLEGAL_CHARACTERS_RE.search(printable):
                value = "base64:" + b64encode(value).decode()
            else:
                value = printable

        # isinstance (rather than an exact type() match) so that subclasses of
        # these types are stringified as well instead of slipping through.
        elif isinstance(value, (ipaddress, list, fieldtypes.posix_path, fieldtypes.windows_path)):
            value = str(value)

        yield value


class XlsxWriter(AbstractWriter):
fp = None
wb = None

def __init__(self, path, **kwargs):
    """Open the output and prepare an empty workbook.

    Args:
        path: Destination handed to ``record.open_path_or_stream`` and
            opened in ``"wb"`` mode.
        **kwargs: Accepted for interface compatibility; not used here.
    """
    self.fp = record.open_path_or_stream(path, "wb")
    self.wb = Workbook()
    self.ws = self.wb.active

    # Remove the active work sheet, every Record Descriptor will have its own sheet.
    self.wb.remove(self.ws)
    # Descriptors that already have a sheet, in order of first appearance.
    self.descs = []
    # Descriptor of the most recently written record; lets write() skip the
    # sheet lookup while consecutive records share a descriptor.
    self._last_dec = None

def write(self, r):
    """Append record ``r`` to the worksheet belonging to its descriptor.

    The first record of each descriptor creates a new sheet whose first row
    holds the field type names and whose second row holds the field names.
    Field values are passed through ``sanitize_fieldvalues`` before being
    appended.

    Args:
        r: The record to write; its ``_desc`` selects the worksheet.

    Raises:
        ValueError: If the workbook refuses the sanitized values.
    """
    desc = r._desc

    if desc != self._last_dec:
        # Worksheets are remembered by object, parallel to self.descs. Looking
        # a sheet up by title again is unreliable: the workbook may alter a
        # title to keep it unique, so two descriptors sharing a name would
        # otherwise end up in the same sheet.
        if not hasattr(self, "_sheets"):
            self._sheets = []

        if desc in self.descs:
            ws = self._sheets[self.descs.index(desc)]
        else:
            ws = self.wb.create_sheet(desc.name.strip().replace("/", "-"))
            fields = desc.get_all_fields()
            ws.append([field.typename for field in fields.values()])
            ws.append(list(fields))
            self.descs.append(desc)
            self._sheets.append(ws)

        self._last_dec = desc
        self.ws = ws

    values = list(sanitize_fieldvalues(r._asdict().values()))

    try:
        self.ws.append(values)
    except ValueError as e:
        raise ValueError(f"Unable to write values to workbook: {str(e)}") from e

def flush(self):
if self.wb:
Expand All @@ -53,7 +109,7 @@ def __init__(self, path, selector=None, **kwargs):
self.selector = make_selector(selector)
self.fp = record.open_path_or_stream(path, "rb")
self.desc = None
self.wb = openpyxl.load_workbook(self.fp)
self.wb = load_workbook(self.fp)
self.ws = self.wb.active

def close(self):
Expand All @@ -62,12 +118,35 @@ def close(self):
self.fp = None

def __iter__(self):
    """Yield the records stored in every worksheet of the workbook.

    Each worksheet is read as one record descriptor: the sheet title (with
    ``-`` turned back into ``/``) is the descriptor name, the first row holds
    the field type names and the second row the field names. Every later row
    is turned into a record. Cells of ``bytes`` fields are decoded from
    either the ``b"<text>"`` or the ``base64:<payload>`` text form.

    Yields:
        Each record, or only those matching ``self.selector`` when one is set.
    """
    for worksheet in self.wb.worksheets:
        desc = None
        desc_name = worksheet.title.replace("-", "/")
        field_names = None
        field_types = None

        for row in worksheet:
            if field_types is None:
                field_types = [col.value for col in row if col.value]
                continue

            if field_names is None:
                field_names = [
                    col.value.replace(" ", "_").lower()
                    for col in row
                    if col.value and not col.value.startswith("_")
                ]
                desc = record.RecordDescriptor(desc_name, list(zip(field_types, field_names)))
                continue

            record_values = []
            for idx, col in enumerate(row):
                value = col.value
                # Rows can be wider than the type row, and a bytes cell can be
                # empty (None); only string cells of bytes fields are decoded.
                is_bytes_field = idx < len(field_types) and field_types[idx] == "bytes"
                if is_bytes_field and isinstance(value, str):
                    if value[1:2] == '"':  # If so, we know this is b""
                        # Cut off the b" at the start and the trailing "
                        value = value[2:-1].encode()
                    else:
                        # If not, we know it is base64 encoded (so we cut off the starting 'base64:')
                        value = b64decode(value[7:])
                record_values.append(value)

            obj = desc(*record_values)
            if not self.selector or self.selector.match(obj):
                yield obj
3 changes: 2 additions & 1 deletion flow/record/fieldtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def flow_record_tz(*, default_tz: str = "UTC") -> Optional[ZoneInfo | UTC]:
try:
return ZoneInfo(tz)
except ZoneInfoNotFoundError as exc:
warnings.warn(f"{exc!r}, falling back to timezone.utc")
if tz != "UTC":
warnings.warn(f"{exc!r}, falling back to timezone.utc")
return UTC


Expand Down
55 changes: 55 additions & 0 deletions tests/test_xlsx_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import re
import sys
from datetime import datetime, timedelta, timezone
from typing import Iterator
from unittest.mock import MagicMock

import pytest

from flow.record import fieldtypes


@pytest.fixture
def mock_openpyxl_package(monkeypatch: pytest.MonkeyPatch) -> Iterator[MagicMock]:
    """Register stand-ins for ``openpyxl`` and ``openpyxl.cell.cell`` in ``sys.modules``.

    The stand-in for ``openpyxl.cell.cell`` carries a real compiled
    ``ILLEGAL_CHARACTERS_RE`` pattern. The stand-in for the top-level package
    is what the fixture yields.
    """
    fake_cell_module = MagicMock()
    fake_cell_module.ILLEGAL_CHARACTERS_RE = re.compile(r"[\000-\010]|[\013-\014]|[\016-\037]")
    fake_package = MagicMock()

    stubs = (
        ("openpyxl", fake_package),
        ("openpyxl.cell.cell", fake_cell_module),
    )

    with monkeypatch.context() as patcher:
        for module_name, stub in stubs:
            patcher.setitem(sys.modules, module_name, stub)

        yield fake_package


def test_sanitize_field_values(mock_openpyxl_package):
    """Check the value produced by ``sanitize_fieldvalues`` for each kind of input."""
    from flow.record.adapter.xlsx import sanitize_fieldvalues

    raw_values = [
        7,
        datetime(1920, 11, 11, 13, 37, 0, tzinfo=timezone(timedelta(hours=2))),
        "James",
        b"Bond",
        b"\x00\x07",
        fieldtypes.net.ipaddress("13.37.13.37"),
        ["Shaken", "Not", "Stirred"],
        fieldtypes.posix_path("/home/user"),
        fieldtypes.posix_command("/bin/bash -c 'echo hello world'"),
        fieldtypes.windows_path("C:\\Users\\user\\Desktop"),
        fieldtypes.windows_command("C:\\Some.exe /?"),
    ]

    expected = [
        7,
        datetime(1920, 11, 11, 11, 37, 0),  # UTC normalization
        "James",
        'b"Bond"',  # When possible, encode bytes in a printable way
        "base64:AAc=",  # If not, base64 encode
        "13.37.13.37",  # Stringify an ip address
        "['Shaken', 'Not', 'Stirred']",  # Stringify a list
        "/home/user",  # Stringify a posix path
        "/bin/bash -c 'echo hello world'",  # Stringify a posix command
        "C:\\Users\\user\\Desktop",  # Stringify a windows path
        "C:\\Some.exe /?",  # Stringify a windows command
    ]

    sanitized = list(sanitize_fieldvalues(raw_values))

    assert sanitized == expected

0 comments on commit d1ba5ce

Please sign in to comment.