Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MappedBinaryIO, Testimplementation for alternating KaitaiStream - maybe #76

Open
Hypnootika opened this issue Oct 11, 2023 · 15 comments
Open

Comments

@Hypnootika
Copy link

Hypnootika commented Oct 11, 2023

Hello everyone,

I'm fairly new to working with binary files and KaitaiStruct. I love it, but unfortunately I don't like the ReadWriteStruct.

I created a different approach based on the Python runtime, and I would like some feedback about possible improvements and/or why it is not suitable for Kaitai.

Please be kind to me: this is my first "package" and definitely the first mmap implementation I have created.

The overall intention is that (if you like the approach) I would try to convert it and improve it further (and create a new/different compiler mode).

If you see mistakes or not logical implementations, please tell me. I want to learn!

Edit1: Note that there are obviously a lot of functions missing that Kaitai needs. This is just the use case I am currently building this around. Take it as a prototype for a possible mmap approach.

Edit2: About the performance: I can't really say much at the moment, but just by testing this I already noticed a gain in speed (the IDE runs the code a lot faster). That's obviously a really bad comparison, but if someone is interested, I could do tests as well.

import os
import struct
from mmap import mmap, ACCESS_COPY
from typing import List, Union


class Parser:
    """Pack and unpack fixed-size primitives identified by Kaitai-style type IDs.

    ``struct_mapping`` maps every supported pattern ID (for example ``"u4le"``)
    to a pre-compiled ``struct.Struct``.  ``range_mapping`` holds the inclusive
    ``(min, max)`` bounds that ``pack_value`` enforces for the same IDs.
    """

    struct_mapping = {
        "u2be": struct.Struct(">H"),
        "u4be": struct.Struct(">I"),
        "u8be": struct.Struct(">Q"),
        "u2le": struct.Struct("<H"),
        "u4le": struct.Struct("<I"),
        "u8le": struct.Struct("<Q"),
        "s1": struct.Struct("b"),
        "s2be": struct.Struct(">h"),
        "s4be": struct.Struct(">i"),
        "s8be": struct.Struct(">q"),
        "s2le": struct.Struct("<h"),
        "s4le": struct.Struct("<i"),
        "s8le": struct.Struct("<q"),
        "f4be": struct.Struct(">f"),
        "f8be": struct.Struct(">d"),
        "f4le": struct.Struct("<f"),
        "f8le": struct.Struct("<d"),
        "u1": struct.Struct("B"),
    }

    range_mapping = {
        "u2be": (0, 65535),
        "u4be": (0, 4294967295),
        "u8be": (0, 18446744073709551615),
        "u2le": (0, 65535),
        "u4le": (0, 4294967295),
        "u8le": (0, 18446744073709551615),
        "s1": (-128, 127),
        "s2be": (-32768, 32767),
        "s4be": (-2147483648, 2147483647),
        "s8be": (-9223372036854775808, 9223372036854775807),
        "s2le": (-32768, 32767),
        "s4le": (-2147483648, 2147483647),
        "s8le": (-9223372036854775808, 9223372036854775807),
        "u1": (0, 255),
        # Largest finite IEEE 754 binary32 value, (2 - 2**-23) * 2**127.  The
        # previous bound of 3.4e38 rejected valid floats just below this limit.
        "f4be": (-3.4028234663852886e38, 3.4028234663852886e38),
        # The literal 1.8e308 is larger than the largest finite double, so
        # Python evaluates it to inf: the f8 check only ever rejects NaN.
        "f8be": (-1.8e308, 1.8e308),
        "f4le": (-3.4028234663852886e38, 3.4028234663852886e38),
        "f8le": (-1.8e308, 1.8e308),
    }

    @classmethod
    def is_value_in_range(cls, pattern_id: str, value: Union[int, float]) -> bool:
        """Return True if ``value`` lies within the inclusive bounds of ``pattern_id``.

        Raises:
            ValueError: if ``pattern_id`` has no entry in ``range_mapping``.
        """
        try:
            min_value, max_value = cls.range_mapping[pattern_id]
        except KeyError:
            raise ValueError(f"Pattern ID {pattern_id} not found.") from None
        return min_value <= value <= max_value

    @classmethod
    def pack_value(cls, pattern_id: str, value: Union[int, float]) -> bytes:
        """Serialize ``value`` using the struct format registered for ``pattern_id``.

        Raises:
            ValueError: if ``pattern_id`` is unknown or ``value`` is outside the
                range registered for it.
        """
        if not cls.is_value_in_range(pattern_id, value):
            raise ValueError(f"Value {value} out of range for pattern ID {pattern_id}.")
        struct_pattern = cls.struct_mapping.get(pattern_id)
        if struct_pattern is None:
            raise ValueError(f"Invalid pattern ID {pattern_id}.")
        return struct_pattern.pack(value)

    def read(self, data: bytes, pattern_id: str) -> bytes:
        """Return the leading bytes of ``data`` that make up one ``pattern_id`` value.

        An unknown ``pattern_id`` falls back to an empty struct (size 0), so the
        result is then an empty prefix of ``data``.
        """
        size = self.struct_mapping.get(pattern_id, struct.Struct("")).size
        return data[:size]

    def read_value(self, data: bytes, pattern_id: str) -> Union[int, float]:
        """Decode a single ``pattern_id`` value from the start of ``data``.

        Raises:
            KeyError: if ``pattern_id`` is not in ``struct_mapping``.
            struct.error: if ``data`` is shorter than the value's size.
        """
        packed_data = self.read(data, pattern_id)
        return self.struct_mapping[pattern_id].unpack(packed_data)[0]

    def read_array(
        self, data: bytes, count: int, pattern_id: str
    ) -> List[Union[int, float]]:
        """Decode ``count`` consecutive ``pattern_id`` values from the start of ``data``."""
        size = self.struct_mapping[pattern_id].size
        return [
            self.read_value(data[i : i + size], pattern_id)
            for i in range(0, count * size, size)
        ]


class BaseMappedBinary:
    """Common state for the memory-mapped reader and writer.

    Opens ``file_path`` (creating it if it does not exist), maps the whole
    file with ``ACCESS_COPY`` and tracks a read cursor in ``self.offset``.
    Usable as a context manager; leaving the ``with`` block calls ``close()``.
    """

    def __init__(self, file_path: str, output_file_path: str = None):
        """Open and map ``file_path``.

        Raises:
            ValueError: propagated from ``mmap`` when the file is empty, which
                is always the case for a file this constructor just created.
        """
        self.file_path = file_path
        self.output_file_path = output_file_path
        if not os.path.exists(self.file_path):
            self.file = open(self.file_path, "w+b")
        else:
            self.file = open(self.file_path, "r+b")
        try:
            self.mapped_file = mmap(self.file.fileno(), 0, access=ACCESS_COPY)
        except Exception:
            # Do not leak the descriptor when the mapping cannot be created.
            self.file.close()
            raise
        self.offset = 0
        self.parser = Parser()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def _read_from_offset(self, size: int) -> bytes:
        """Return up to ``size`` bytes at the current offset without moving it."""
        return self.mapped_file[self.offset : self.offset + size]

    def _update_offset(self, size: int):
        """Advance the current offset by ``size`` bytes."""
        self.offset += size

    def close(self):
        """Close the memory map and the underlying file object."""
        try:
            self.mapped_file.close()
        finally:
            # Close the file even if closing the map raised.
            self.file.close()

    def seek(self, offset: int) -> int:
        """Set the current offset to ``offset`` and return it."""
        self.offset = offset
        return self.offset

    def tell(self) -> int:
        """Return the current offset."""
        return self.offset

    def flush(self):
        """Flush the memory map."""
        self.mapped_file.flush()


class MappedBinaryReader(BaseMappedBinary):
    """Sequential reader over the memory-mapped file.

    Except for ``read``, every ``read_*`` method consumes the bytes it returns
    by advancing ``self.offset``.
    """

    def __init__(self, file_path: str):
        super().__init__(file_path, output_file_path=None)

    def read(self, pattern_id: str) -> bytes:
        """Return the raw bytes of one ``pattern_id`` value; the offset is not moved."""
        size = self.parser.struct_mapping[pattern_id].size
        return self.parser.read(self._read_from_offset(size), pattern_id)

    def read_value(self, pattern_id: str) -> Union[int, float]:
        """Decode one ``pattern_id`` value at the current offset and advance past it."""
        size = self.parser.struct_mapping[pattern_id].size
        value = self.parser.read_value(self._read_from_offset(size), pattern_id)
        self._update_offset(size)
        return value

    def read_array(self, count: int, pattern_id: str) -> List[Union[int, float]]:
        """Decode ``count`` consecutive ``pattern_id`` values and advance past them."""
        size = self.parser.struct_mapping[pattern_id].size
        values = self.parser.read_array(
            self._read_from_offset(count * size), count, pattern_id
        )
        self._update_offset(count * size)
        return values

    def read_string(self, count: int) -> str:
        """Read ``count`` bytes, decode them as UTF-8 and advance past them."""
        value = self._read_from_offset(count).decode("utf-8")
        self._update_offset(count)
        return value

    def read_string_array(self, count: int) -> List[str]:
        """Read ``count`` strings, each of them ``count`` bytes long."""
        return [self.read_string(count) for _ in range(count)]

    def read_string_array_with_count(self) -> List[str]:
        """Read a ``u4le`` count, then a string array of that count."""
        count = self.read_value("u4le")
        return self.read_string_array(count)

    def read_string_with_count(self) -> str:
        """Read a ``u4le`` byte length, then a UTF-8 string of that length."""
        count = self.read_value("u4le")
        return self.read_string(count)

    def read_bytes(self, count: int) -> bytes:
        """Read ``count`` raw bytes and advance past them.

        Previously the offset was left untouched, so the next read returned the
        same bytes again, unlike ``read_string`` and ``read_value``.
        """
        data = self._read_from_offset(count)
        self._update_offset(count)
        return data

    def read_bytes_with_count(self) -> bytes:
        """Read a ``u4le`` byte length, then that many raw bytes, advancing past both."""
        count = self.read_value("u4le")
        return self.read_bytes(count)

    def read_value_array_with_count(self, pattern_id: str) -> List[Union[int, float]]:
        """Read a ``u4le`` element count, then that many ``pattern_id`` values."""
        count = self.read_value("u4le")
        return self.read_array(count, pattern_id)

    def read_value_array(self, count: int, pattern_id: str) -> List[Union[int, float]]:
        """Alias of ``read_array``."""
        return self.read_array(count, pattern_id)


class MappedBinaryWriter(BaseMappedBinary):
    """Collects serialized values in an in-memory ``bytes`` buffer (``self.data``).

    Nothing is written to the mapped file by these methods; callers fetch the
    accumulated bytes with ``get_data()``.
    """

    def __init__(self, file_path: str):
        super().__init__(file_path, output_file_path=None)
        self.data = b""

    def get_data(self) -> bytes:
        """Return the collected data as bytes"""
        return self.data

    def write(self, pattern_id: str, value: Union[int, float]) -> None:
        """Append ``value`` packed as ``pattern_id``.

        Raises:
            ValueError: if the pattern ID is unknown or the value is out of range.
        """
        self.data += self.parser.pack_value(pattern_id, value)

    def write_value(self, pattern_id: str, value: Union[int, float]) -> None:
        """Alias of ``write``."""
        self.write(pattern_id, value)

    def write_array(self, pattern_id: str, values: List[Union[int, float]]) -> None:
        """Append every element of ``values`` packed as ``pattern_id``."""
        for value in values:
            self.write_value(pattern_id, value)

    def write_value_array(
        self, pattern_id: str, values: List[Union[int, float]]
    ) -> None:
        """Alias of ``write_array``."""
        self.write_array(pattern_id, values)

    def write_bytes(self, value: bytes) -> None:
        """Append raw bytes unchanged."""
        self.data += value

    def write_bytes_with_count(self, value: bytes) -> None:
        """Append a ``u4le`` byte length followed by the raw bytes."""
        self.write_value("u4le", len(value))
        self.write_bytes(value)

    def write_string(self, value: str) -> None:
        """Append ``value`` encoded as UTF-8, without any length prefix."""
        self.data += value.encode("utf-8")

    def write_string_array(self, values: List[str]) -> None:
        """Append each string UTF-8 encoded, back to back, without length prefixes."""
        for value in values:
            self.write_string(value)

    def write_string_array_with_count(self, values: List[str]) -> None:
        """Append a ``u4le`` element count followed by the strings themselves."""
        self.write_value("u4le", len(values))
        self.write_string_array(values)

    def write_string_with_count(self, value: str) -> None:
        """Append a ``u4le`` byte length followed by the UTF-8 encoded string.

        The prefix is the length of the encoded bytes, not the number of
        characters, so it matches what a reader must consume for non-ASCII text.
        """
        encoded = value.encode("utf-8")
        self.write_value("u4le", len(encoded))
        self.write_bytes(encoded)

    def write_value_array_with_count(
        self, pattern_id: str, values: List[Union[int, float]]
    ) -> None:
        """Append a ``u4le`` element count followed by the packed values."""
        self.write_value("u4le", len(values))
        self.write_array(pattern_id, values)


class MappedBinaryIO(MappedBinaryReader, MappedBinaryWriter):
    """Facade that pairs a ``MappedBinaryReader`` and a ``MappedBinaryWriter``.

    The inherited ``__init__`` is not called: this object owns no mapping or
    offset of its own and forwards the overridden methods below to
    ``self.reader`` / ``self.writer``, which both map ``file_path``.
    """

    def __init__(self, file_path: str, output_file_path: str = None):
        """Open a reader and a writer on ``file_path``.

        ``output_file_path`` defaults to ``file_path + ".bin"``.
        """
        self.file_path = file_path

        if output_file_path is None:
            self.output_file_path = file_path + ".bin"
        else:
            self.output_file_path = output_file_path
        self.reader = MappedBinaryReader(self.file_path)
        try:
            self.writer = MappedBinaryWriter(self.file_path)
        except Exception:
            # Do not leave the reader's map and file handle open on failure.
            self.reader.close()
            raise

    def read_value(self, pattern_id: str) -> Union[int, float]:
        """Read one value through the reader."""
        return self.reader.read_value(pattern_id)

    def write_value(self, pattern_id: str, value: Union[int, float]) -> None:
        """Append one value through the writer."""
        self.writer.write_value(pattern_id, value)

    def flush(self) -> None:
        """Flush the writer's memory map."""
        self.writer.flush()

    def seek(self, offset: int) -> int:
        """Move the reader's offset and return it."""
        return self.reader.seek(offset)

    def tell(self) -> int:
        """Return the reader's current offset."""
        return self.reader.tell()

    def close(self) -> None:
        """Close both the reader and the writer."""
        try:
            self.reader.close()
        finally:
            # Close the writer even if closing the reader raised.
            self.writer.close()

and a testfile class:



class ExpFile(MappedBinaryIO):
    """Example format: a 4-byte magic string followed by three little-endian header fields.

    The header is parsed on construction; ``write_to_file`` serializes the
    same four fields to ``output_file_path``.
    """

    def __init__(self, file_path: str, output_file_path: str = None):
        super().__init__(file_path)
        self._read()
        # Snapshot of whatever the writer has collected so far.
        self.data = self.writer.get_data()
        self.output_file_path = (
            file_path + ".bin" if output_file_path is None else output_file_path
        )
        self.mapped_file = self.reader.mapped_file

    def _read(self):
        """Parse magic, version, uk and header_size from the reader, in that order."""
        source = self.reader
        self.magic = source.read_string(4)
        self.version = source.read_value("u2le")
        self.uk = source.read_value("u4le")
        self.header_size = source.read_value("u4le")

    def __repr__(self):
        return (
            f"ExpFile({self.magic=}, {self.version=}, {self.uk=}, {self.header_size=})"
        )

    def _write(self):
        """Append the four header fields to the writer and return its collected bytes."""
        sink = self.writer
        sink.write_string(self.magic)
        sink.write("u2le", self.version)
        sink.write("u4le", self.uk)
        sink.write("u4le", self.header_size)
        return sink.get_data()

    def write_to_file(self):
        """Write the serialized header to ``self.output_file_path``."""
        with open(self.output_file_path, "wb") as out:
            out.write(self._write())


if __name__ == "__main__":
    # The context manager closes both memory maps and file handles on exit,
    # even if writing or printing raises.
    with ExpFile(r"D:\binparser\eso0001.dat") as mt:
        mt.write_to_file()
        print(mt)
        print(mt.tell())
@Hypnootika
Copy link
Author

Hypnootika commented Oct 11, 2023

#67 might be related

@KOLANICH
Copy link
Contributor

read_array

Python has support for native typed arrays, though I'm not sure they interplay well with mmap. There is also ctypes, with support for C structs and arrays of them, and I wonder if there are any benefits to utilizing them.

It seems that the current code still doubles memory usage, since it parses native types and reads them into fields of a Python object. If we just want to use memory mapping, would just passing an mmap object into KaitaiStream solve the issue? What I meant in #65 is kind of generating accessors for all the fields: when you get or set a value, you just dereference a pointer to raw memory, and no data is duplicated into KS objects. It feels like that non-stream model cannot be properly implemented in the Python runtime alone, but would require the compiler to generate the needed code directly.

@Hypnootika
Copy link
Author

Well, yes. I guess 50% of the performance deficit is more a bad (and repetitive) implementation.

I had exceptional results with reading. But as soon as it comes to writing the problems begin.

First of all, you can't mmap an empty file on Windows, which is really annoying.

About Integrating it into the current Runtime:

Yes, that should be possible, but it would require changes to the compiler as well. I could figure out some prototypes for mmapped Kaitai in Python, but (don't laugh) three weeks ago I didn't even know what Scala is. I tried changing things around a bit, but (don't get me wrong) the current compiler is really messy (for me).

About the arrays:
It was my experimental approach to reduce the for loops and if statements needed to parse things like records. Unfortunately, after I played around with my implementation, it felt kind of redundant.

About the memory:

Currently I'm using ACCESS_COPY and have also created some kind of buffer with the write methods. I guess one could easily reduce memory usage by changing it to write access.

Also, closing the mmap file and the FileIO object is a bit wonky. As far as I can tell, you need the FileIO object just for the mmap creation; after that it can be dropped.

@Hypnootika
Copy link
Author

But also there is ctypes with support if C-structs and arrays of them and I wonder if there is any benefits of utilizing them.

Well, I didn't touch ctypes anymore because it has no "proper" (convenient) way of handling endianness.

@Hypnootika
Copy link
Author

So, I prepared some (hopefully comparable) approaches, in case someone wants to play around with them:

import io
import mmap
import time
from abc import ABC, abstractmethod
from ctypes import *
from functools import wraps
from struct import Struct
from types import TracebackType
from typing import Any, Iterator, Optional, Union, Type, List


def timeit(method):
    """Decorator that prints how long each call to ``method`` takes.

    Two lines are printed after every successful call: the qualified name of
    the wrapped callable, then its name with the elapsed time in seconds.
    The wrapped callable's return value is passed through unchanged; nothing
    is printed if it raises.
    """

    @wraps(method)
    def timed(*args, **kw):
        # perf_counter is monotonic and high-resolution; time.time() is a wall
        # clock that can jump and is too coarse for short calls.
        ts = time.perf_counter()
        result = method(*args, **kw)
        duration = time.perf_counter() - ts
        print(f"{method.__module__}.{method.__qualname__}")
        print(f"{method.__name__}: {duration:.6f} sec")

        return result

    return timed


class FormatString:
    """A class that represents and manipulates format strings.

    This class provides a unified interface for working with format strings
    across different libraries like struct, ctypes, and kaitaistruct.

    A format is given either as a Kaitai type ID (``"u4le"``) or as a bare
    single-character struct code (``"I"``); anything else raises ValueError.
    """

    # Class attribute that maps ctypes to struct
    CTYPES_TO_STRUCT = {
        c_char: "c",
        c_byte: "b",
        c_ubyte: "B",
        c_short: "h",
        c_ushort: "H",
        c_int: "i",
        c_uint: "I",
        c_long: "l",
        c_ulong: "L",
        c_longlong: "q",
        c_ulonglong: "Q",
        c_float: "f",
        c_double: "d",
    }

    # Class attribute that maps struct to ctypes.  Spelled out rather than
    # inverted from CTYPES_TO_STRUCT: ctypes aliases same-sized integer types
    # (e.g. c_int is c_long on some platforms), which collapses keys above and
    # would silently drop struct codes from an inverted dict.
    STRUCT_TO_CTYPES = {
        "c": c_char,
        "b": c_byte,
        "B": c_ubyte,
        "h": c_short,
        "H": c_ushort,
        "i": c_int,
        "I": c_uint,
        "l": c_long,
        "L": c_ulong,
        "q": c_longlong,
        "Q": c_ulonglong,
        "f": c_float,
        "d": c_double,
    }

    # Class attribute that maps kaitai to struct
    KAITAI_TO_STRUCT = {
        # "b" is the signed 8-bit integer code; "c" is a length-1 bytes object
        # and cannot pack or unpack integers.
        "s1": "b",
        "u1": "B",
        "s2le": "<h",
        "u2le": "<H",
        "s4le": "<i",
        "u4le": "<I",
        "s8le": "<q",
        "u8le": "<Q",
        "f4le": "<f",
        "f8le": "<d",
        "s2be": ">h",
        "u2be": ">H",
        "s4be": ">i",
        "u4be": ">I",
        "s8be": ">q",
        "u8be": ">Q",
        "f4be": ">f",
        "f8be": ">d",
    }

    # Class attribute that maps struct to kaitai
    STRUCT_TO_KAITAI = {value: key for key, value in KAITAI_TO_STRUCT.items()}

    def __init__(self, fmt: str):
        self.set_format(fmt)

    def set_format(self, fmt: str):
        """Select the format, given as a Kaitai type ID or a bare struct code.

        Raises:
            ValueError: if ``fmt`` is neither a key of ``KAITAI_TO_STRUCT`` nor
                a key of ``STRUCT_TO_CTYPES``.
        """
        if fmt in self.KAITAI_TO_STRUCT:
            self._fmt = self.KAITAI_TO_STRUCT[fmt]
        elif fmt in self.STRUCT_TO_CTYPES:
            self._fmt = fmt
        else:
            raise ValueError(f"Unsupported format: {fmt}")
        self._struct = Struct(self._fmt)
        # Either lookup may yield None: byte-order-prefixed formats have no
        # ctypes entry, and most bare codes have no Kaitai ID.
        self._kaitai = self.STRUCT_TO_KAITAI.get(self._fmt)
        self._ctypes = self.STRUCT_TO_CTYPES.get(self._fmt)

    @classmethod
    def from_ctypes_to_struct(cls, ctypes_format):
        """Converts a ctypes type to a struct code (None if unknown)."""
        return cls.CTYPES_TO_STRUCT.get(ctypes_format)

    @classmethod
    def from_struct_to_ctypes(cls, struct_format):
        """Converts a struct code to a ctypes type (None if unknown)."""
        return cls.STRUCT_TO_CTYPES.get(struct_format)

    @property
    def fmt(self):
        """Returns the struct format string."""
        return self._fmt

    @property
    def struct(self):
        """Returns the Struct object."""
        return self._struct

    @property
    def kaitai(self):
        """Returns the Kaitai type ID for this format, or None."""
        return self._kaitai

    @property
    def ctypes(self):
        """Returns the ctypes type for this format, or None."""
        return self._ctypes

    @property
    def size(self):
        """Returns the size of the format in bytes."""
        return self._struct.size

    # Define other class methods
    def pack(self, *args):
        """Packs values into a binary string according to the format."""
        return self._struct.pack(*args)

    def unpack(self, *args):
        """Unpacks values from a binary string according to the format."""
        return self._struct.unpack(*args)

    def pack_into(self, buffer, offset, *args):
        """Packs values into a buffer according to the format."""
        self._struct.pack_into(buffer, offset, *args)

    def unpack_from(self, buffer, offset=0):
        """Unpacks values from a buffer according to the format."""
        return self._struct.unpack_from(buffer, offset)

    def __getitem__(self, index):
        """Gets a field from the Struct object."""
        # NOTE(review): struct.Struct does not support indexing, so this raises
        # TypeError as written; the intended behavior needs to be confirmed.
        return self._struct[index]

    def __repr__(self):
        """Returns a string representation of the FormatString object."""
        return "FormatString(%r)" % self._fmt

    def __str__(self):
        """Returns the format string."""
        return self._fmt

    def __eq__(self, other):
        """Checks equality with another FormatString or a string."""
        if isinstance(other, FormatString):
            return self._fmt == other._fmt
        elif isinstance(other, str):
            return self._fmt == other
        return NotImplemented

    def __hash__(self):
        """Hash of the format string, consistent with ``__eq__`` (including vs. str)."""
        # Defining __eq__ alone sets __hash__ to None and makes instances
        # unusable in sets and as dict keys.
        return hash(self._fmt)


class FileFormat(ABC):
    """Interface for format-string driven access to a file.

    Every method is abstract; the base implementations do nothing and return
    None.
    """

    @abstractmethod
    def __init__(self, file: Union[str, "io.IOBase"]):
        """Set up the format from a file name or an open file object."""

    @abstractmethod
    def read(self, fmt: str, offset: Optional[int] = None) -> Any:
        """Read data described by the format string ``fmt``, optionally at ``offset``."""

    @abstractmethod
    def write(self, fmt: str, *values: Any, offset: Optional[int] = None) -> None:
        """Write ``values`` as described by the format string ``fmt``."""

    @abstractmethod
    def getvalue(self) -> bytes:
        """Return the file content as bytes."""

    @abstractmethod
    def close(self) -> None:
        """Close the file."""

    # ... rest of the methods ...

    @abstractmethod
    def __enter__(self) -> "FileFormat":
        """Enter the context of the file format."""

    @abstractmethod
    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        """Exit the context of the file format."""

    @abstractmethod
    def __iter__(self) -> Iterator:
        """Return an iterator for the file format."""

    @abstractmethod
    def __next__(self) -> Any:
        """Advance to the next item of the file format."""


class AbstractFile(ABC):
    """Interface for byte-oriented file objects.

    Both methods are abstract; the base implementations do nothing and return
    None.
    """

    @abstractmethod
    def read(self, length: int) -> bytes:
        """Read ``length`` bytes."""

    @abstractmethod
    def write(self, data: bytes) -> None:
        """Write ``data``."""


class BinaryFile(AbstractFile):
    """Binary file accessed through a ``FileFormat``, with optional named sections."""

    def __init__(self, file_path: str):
        # NOTE(review): FileFormat only declares abstract methods in this file,
        # so instantiating it directly raises TypeError; presumably a concrete
        # subclass is meant here - confirm.
        self.file_format = FileFormat(file_path)
        self.sections: List[BinaryFile.BinarySection] = []

    def read(self, length: int) -> bytes:
        """Read ``length`` bytes using the ``"<length>s"`` format."""
        fmt = f"{length}s"
        return self.file_format.read(fmt)

    def write(self, data: bytes) -> None:
        """Write ``data`` using the ``"<len(data)>s"`` format."""
        fmt = f"{len(data)}s"
        self.file_format.write(fmt, data)

    def add_section(self, offset: int, length: int) -> "BinaryFile.BinarySection":
        """Create a section at ``offset``, remember it in ``self.sections`` and return it."""
        new_section = self.BinarySection(self.file_format, offset, length)
        self.sections.append(new_section)
        return new_section

    class BinarySection:
        """A region of the file starting at ``offset`` and spanning ``length`` bytes."""

        def __init__(self, file_format: Type[FileFormat], offset: int, length: int):
            # ``file_format`` is used as an instance (its ``read`` is called),
            # although the annotation names the class type.
            self.file_format = file_format
            self.offset = offset
            self.length = length

        def read(self, fmt: str, offset: Optional[int] = None) -> Any:
            """Read ``fmt`` at ``offset`` relative to the start of this section."""
            if offset:
                absolute_offset = self.offset + offset
            else:
                # None and 0 both mean "start of the section".
                absolute_offset = self.offset
            return self.file_format.read(fmt, absolute_offset)


class VirtualFile(AbstractFile):
    """An in-memory file backed by a growable ``bytearray`` with a position cursor."""

    def __init__(self):
        self.buffer = bytearray()
        # Current read/write position within the buffer.
        self.position = 0

    def write(self, data: Union[bytes, bytearray]) -> None:
        """Write ``data`` at the current position, growing the buffer if needed.

        If the position lies beyond the end of the buffer, the gap is filled
        with zero bytes.  The position ends up just past the written data.
        """
        data_length = len(data)
        new_position = self.position + data_length

        # Extend buffer if necessary
        buffer_length = len(self.buffer)
        if new_position > buffer_length:
            self.buffer.extend(b"\x00" * (new_position - buffer_length))

        # Write data to buffer
        self.buffer[self.position : new_position] = data
        self.position = new_position

    def read(self, length: int) -> bytes:
        """Read up to ``length`` bytes from the current position.

        Returns an immutable ``bytes`` copy, as annotated, and advances the
        position only by the number of bytes actually returned, so a short
        read at the end of the buffer no longer leaves the position past it.
        """
        start = self.position
        data = bytes(self.buffer[start : start + length])
        self.position = start + len(data)
        return data

    def seek(self, offset: int, whence: Optional[int] = 0) -> None:
        """Move the position: whence 0 = absolute, 1 = relative, 2 = from the end.

        Raises:
            ValueError: if ``whence`` is not 0, 1 or 2.
        """
        if whence == 0:
            self.position = offset
        elif whence == 1:
            self.position += offset
        elif whence == 2:
            self.position = len(self.buffer) + offset
        else:
            raise ValueError("Invalid value for whence")

    def tell(self) -> int:
        """Return the current position."""
        return self.position

    def flush(self, file_path: str) -> None:
        """Write the whole buffer to ``file_path``, replacing its content."""
        with open(file_path, "wb") as file:
            file.write(self.buffer)

    def truncate(self, size: Optional[int] = None) -> None:
        """Cut the buffer to ``size`` bytes (default: the current position)."""
        if size is None:
            size = self.position
        self.buffer = self.buffer[:size]

    def getvalue(self) -> bytes:
        """Return the whole buffer as ``bytes``."""
        return bytes(self.buffer)


class FileFactory:
    """Static constructors for the file classes defined in this module."""

    @staticmethod
    def create_file(file_type: Type[AbstractFile], *args, **kwargs) -> AbstractFile:
        """Instantiate ``file_type`` with the given arguments."""
        return file_type(*args, **kwargs)

    @staticmethod
    def create_binary_file(file_path: str) -> BinaryFile:
        """Create a ``BinaryFile`` for ``file_path``."""
        return BinaryFile(file_path)

    @staticmethod
    def create_virtual_file() -> VirtualFile:
        """Create an empty ``VirtualFile``."""
        return VirtualFile()

    @staticmethod
    def create_file_from_bytes(data: bytes) -> VirtualFile:
        """Create a ``VirtualFile`` holding ``data``.

        The data is stored with ``write``, so the returned file's position is
        at the end of the data.
        """
        virtual = VirtualFile()
        virtual.write(data)
        return virtual

    @staticmethod
    def create_file_from_file(file_path: str) -> VirtualFile:
        """Load the content of the file at ``file_path`` into a ``VirtualFile``."""
        with open(file_path, "rb") as handle:
            return FileFactory.create_file_from_bytes(handle.read())

    @staticmethod
    def create_file_from_file_format(file_format: FileFormat) -> VirtualFile:
        """Copy ``file_format.getvalue()`` into a ``VirtualFile``."""
        content = file_format.getvalue()
        return FileFactory.create_file_from_bytes(content)

    @staticmethod
    def create_file_from_binary_file(binary_file: BinaryFile) -> VirtualFile:
        """Copy the content behind ``binary_file.file_format`` into a ``VirtualFile``."""
        return FileFactory.create_file_from_file_format(binary_file.file_format)

    @staticmethod
    def create_file_from_virtual_file(virtual_file: VirtualFile) -> VirtualFile:
        """Copy ``virtual_file`` (via its ``getvalue``) into a new ``VirtualFile``."""
        return FileFactory.create_file_from_file_format(virtual_file)

    @staticmethod
    def create_file_from_file_path(file_path: str) -> VirtualFile:
        """Alias of ``create_file_from_file``."""
        return FileFactory.create_file_from_file(file_path)

    @staticmethod
    def create_file_from_file_object(file_object: io.IOBase) -> VirtualFile:
        """Read everything from ``file_object`` into a ``VirtualFile``.

        Raises:
            IOError: if ``file_object`` is not readable.
        """
        if not file_object.readable():
            raise IOError("File object is not readable")
        content = file_object.read()
        return FileFactory.create_file_from_bytes(content)

    @staticmethod
    def create_file_from_file_like_object(file_like_object: io.BytesIO) -> VirtualFile:
        """Alias of ``create_file_from_file_object``."""
        return FileFactory.create_file_from_file_object(file_like_object)


class AbstractIO(ABC):
    """Interface shared by the IO back-ends; every method is abstract.

    The base ``read`` and ``write`` are wrapped in ``timeit``; all base
    implementations do nothing and return None.
    """

    @abstractmethod
    @timeit
    def read(self, size: int) -> bytes:
        """Read ``size`` bytes."""

    @abstractmethod
    @timeit
    def write(self, data: bytes) -> None:
        """Write ``data``."""

    @abstractmethod
    def seek(self, offset: int, whence: int = 0) -> None:
        """Move the position to ``offset``, interpreted according to ``whence``."""

    @abstractmethod
    def tell(self) -> int:
        """Return the current position."""

    @abstractmethod
    def close(self) -> None:
        """Release the underlying resource."""


class BinaryStructIO(AbstractIO):
    """File IO that packs and unpacks records with one fixed struct format."""

    def __init__(self, file_path: str, fmt: str, mode: str = "rb+"):
        self._file = open(file_path, mode)
        self.fmt = fmt
        self.formathandler = FormatString(fmt)
        self._struct = self.formathandler.struct

    @timeit
    def read(self, size: int = 0) -> bytes:
        """Read bytes from the file and unpack them with the struct format.

        ``size == 0`` reads exactly one record (``self._struct.size`` bytes);
        any other value reads that many bytes.  The return value is whatever
        ``Struct.unpack`` returns for them.
        """
        byte_count = self._struct.size if size == 0 else size
        raw = self._file.read(byte_count)
        return self._struct.unpack(raw)

    @timeit
    def write(self, data: bytes) -> None:
        """Write ``data`` as-is if it is ``bytes``; otherwise pack it as one value first."""
        payload = data if isinstance(data, bytes) else self._struct.pack(data)
        self._file.write(payload)

    def seek(self, offset: int, whence: int = 0) -> None:
        """Forward to the file's ``seek``."""
        self._file.seek(offset, whence)

    def tell(self) -> int:
        """Return the file's current position."""
        return self._file.tell()

    def close(self) -> None:
        """Close the file."""
        self._file.close()


class FileIO(AbstractIO):
    """IO back-end that forwards every call to a file opened with ``open``."""

    def __init__(self, file_path: str, mode: str = "rb+"):
        """Open ``file_path`` in ``mode`` (default ``"rb+"``)."""
        self._file = open(file_path, mode)

    @timeit
    def read(self, size: int) -> bytes:
        """Return the result of reading ``size`` bytes from the file."""
        return self._file.read(size)

    @timeit
    def write(self, data: bytes) -> None:
        """Write ``data`` to the file."""
        self._file.write(data)

    def seek(self, offset: int, whence: int = 0) -> None:
        """Forward to the file's ``seek``."""
        self._file.seek(offset, whence)

    def tell(self) -> int:
        """Return the file's current position."""
        return self._file.tell()

    def close(self) -> None:
        """Close the file."""
        self._file.close()


class MemoryIO(AbstractIO):
    """In-memory IO over a ``bytearray``, read through a ``memoryview``."""

    def __init__(self, data: Union[bytes, bytearray, str] = b""):
        """Copy ``data`` into the internal buffer; ``str`` input is UTF-8 encoded."""
        if isinstance(data, str):
            # bytearray(str) without an encoding raises TypeError.
            data = data.encode("utf-8")
        self._buffer = bytearray(data)
        self._position = 0
        self._memoryview = memoryview(self._buffer)

    @timeit
    def read(self, size: int) -> bytes:
        """Return up to ``size`` bytes from the current position and advance by ``size``."""
        start = self._position
        end = self._position + size
        self._position = end
        return bytes(self._memoryview[start:end])

    def readinto(self, b: Union[bytearray, memoryview]) -> int:
        """Fill ``b`` from the current position and return the number of bytes copied.

        At most ``len(b)`` bytes are copied.  When fewer remain, only the
        leading part of ``b`` is overwritten (its length is never changed) and
        the position advances by the number of bytes actually copied.
        """
        start = self._position
        chunk = self._memoryview[start : start + len(b)]
        copied = len(chunk)
        b[:copied] = chunk
        self._position = start + copied
        return copied

    def resize(self, size: int) -> None:
        """Truncate the buffer to its first ``size`` bytes."""
        self._buffer = self._buffer[:size]
        self._memoryview = memoryview(self._buffer)

    @timeit
    def write(self, data: bytes) -> None:
        """Write ``data`` at the current position, zero-extending the buffer if needed."""
        data_len = len(data)
        end = self._position + data_len
        if end > len(self._buffer):
            # A bytearray cannot be resized while a memoryview exports it, so
            # the view is released first and recreated afterwards.
            self._memoryview.release()
            self._buffer.extend(b"\x00" * (end - len(self._buffer)))
            self._memoryview = memoryview(self._buffer)
        self._buffer[self._position : end] = data
        self._position = end

    def seek(self, offset: int, whence: int = 0) -> None:
        """Move the position: whence 0 = absolute, 1 = relative, 2 = from the end.

        Raises:
            ValueError: if ``whence`` is not 0, 1 or 2.
        """
        if whence == 0:
            self._position = offset
        elif whence == 1:
            self._position += offset
        elif whence == 2:
            self._position = len(self._buffer) + offset
        else:
            raise ValueError("Invalid value for whence")

    def tell(self) -> int:
        """Return the current position."""
        return self._position

    def close(self) -> None:
        """No-op."""
        pass


class MMapIO(AbstractIO):
    """IO back-end that forwards every call to an ``mmap`` of the whole file."""

    def __init__(self, file_path: str, mode: str = "r+b"):
        """Open ``file_path`` and map it.

        Any exception raised by ``mmap.mmap`` is propagated after the file
        handle has been closed again.
        """
        self._file = open(file_path, mode)
        try:
            self._mmap = mmap.mmap(self._file.fileno(), 0)
        except Exception:
            # Do not leak the descriptor when the mapping cannot be created.
            self._file.close()
            raise

    @timeit
    def read(self, size: int) -> bytes:
        """Return the result of reading ``size`` bytes from the map."""
        return self._mmap.read(size)

    @timeit
    def write(self, data: bytes) -> None:
        """Write ``data`` into the map at its current position."""
        self._mmap.write(data)

    def seek(self, offset: int, whence: int = 0) -> None:
        """Forward to the map's ``seek``."""
        self._mmap.seek(offset, whence)

    def tell(self) -> int:
        """Return the map's current position."""
        return self._mmap.tell()

    def close(self) -> None:
        """Close the map and then the file."""
        try:
            self._mmap.close()
        finally:
            # Close the file even if closing the map raised.
            self._file.close()


class BufferIO(AbstractIO):
    """In-memory IO over a plain ``bytearray`` with a position cursor."""

    def __init__(self, data: bytes = b""):
        self._buffer = bytearray(data)
        self._position = 0

    @timeit
    def read(self, size: int) -> bytes:
        """Return up to ``size`` bytes from the current position and advance by ``size``."""
        begin = self._position
        self._position = begin + size
        return bytes(self._buffer[begin : self._position])

    @timeit
    def write(self, data: bytes) -> None:
        """Write ``data`` at the current position, zero-extending the buffer if needed."""
        begin = self._position
        stop = begin + len(data)
        missing = stop - len(self._buffer)
        if missing > 0:
            # bytes(n) is n zero bytes.
            self._buffer.extend(bytes(missing))
        self._buffer[begin:stop] = data
        self._position = stop

    def seek(self, offset: int, whence: int = 0) -> None:
        """Move the position: whence 0 = absolute, 1 = relative, 2 = from the end.

        Raises:
            ValueError: if ``whence`` is not 0, 1 or 2.
        """
        if whence == 0:
            self._position = offset
        elif whence == 1:
            self._position += offset
        elif whence == 2:
            self._position = len(self._buffer) + offset
        else:
            raise ValueError("Invalid value for whence")

    def tell(self) -> int:
        """Return the current position."""
        return self._position

    def close(self) -> None:
        """No-op: nothing to close for an in-memory buffer."""
        pass  # nothing to close in memory IO


class BytesIO(AbstractIO):
    """Adapter that forwards every operation to a wrapped ``io.BytesIO``."""

    def __init__(self, data: bytes = b""):
        """Wrap a new ``io.BytesIO`` initialised with *data*."""
        self._position = 0
        self._buffer = io.BytesIO(data)

    @timeit
    def read(self, size: int) -> bytes:
        """Read up to *size* bytes from the wrapped stream."""
        wrapped = self._buffer
        return wrapped.read(size)

    @timeit
    def write(self, data: bytes) -> None:
        """Write *data* to the wrapped stream."""
        wrapped = self._buffer
        wrapped.write(data)

    def seek(self, offset: int, whence: int = 0) -> None:
        """Reposition the wrapped stream."""
        wrapped = self._buffer
        wrapped.seek(offset, whence)

    def tell(self) -> int:
        """Return the wrapped stream's position."""
        wrapped = self._buffer
        return wrapped.tell()

    def close(self) -> None:
        """Close the wrapped stream."""
        wrapped = self._buffer
        wrapped.close()

    def getvalue(self) -> bytes:
        """Return the full contents of the wrapped stream."""
        wrapped = self._buffer
        return wrapped.getvalue()

    def __enter__(self) -> "BytesIO":
        """Enter a ``with`` block, yielding this adapter."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the adapter when the ``with`` block ends."""
        self.close()

    def __iter__(self) -> Iterator:
        """Return this adapter as its own iterator."""
        return self

    def __next__(self) -> Any:
        """Return the next item produced by the wrapped stream."""
        return next(self._buffer)

    def __repr__(self) -> str:
        """Return the wrapped stream's representation."""
        return repr(self._buffer)


class UniversalIO:
    """Facade that selects one IO handler by name and forwards calls to it."""

    def __init__(self, source=None, io_type="file", mode="rb+", fmt=None):
        """Create the handler named by *io_type*.

        Args:
            source: File path, or raw data for the in-memory handlers.
            io_type: One of "file", "memory", "mmap", "struct", "buffer",
                "bytes".
            mode: File mode forwarded to the file-based handlers.
            fmt: Format forwarded to ``FormatString`` and to the "struct"
                handler.

        Raises:
            ValueError: If *io_type* is unknown, or if it is "file" and no
                *source* was given.
        """
        self.io_type = io_type
        self.mode = mode
        self.fmt = fmt
        self.formathandler = FormatString(fmt)

        # A str source for an in-memory handler is a path whose contents are
        # loaded.  The membership test is grouped explicitly: the original
        # ``a or b or c and isinstance(...)`` bound ``and`` tighter than
        # ``or``, so bytes sources for "buffer"/"memory" were passed to open().
        if io_type in ("buffer", "memory", "bytes") and isinstance(source, str):
            with open(source, "rb") as handle:
                source = handle.read()
        if io_type == "file" and source is None:
            raise ValueError("File path is required for file IO")

        if io_type == "file":
            self.io_handler = FileIO(source, mode)
        elif io_type == "memory":
            self.io_handler = MemoryIO(source)
        elif io_type == "mmap":
            self.io_handler = MMapIO(source, mode)
        elif io_type == "struct":
            self.io_handler = BinaryStructIO(source, self.fmt or None, mode)
        elif io_type == "buffer":
            self.io_handler = BufferIO(source)
        elif io_type == "bytes":
            self.io_handler = BytesIO(source)
        else:
            raise ValueError(f"Unsupported IO type: {io_type}")

    def read(self, size: int) -> bytes:
        """Read up to *size* bytes from the selected handler."""
        return self.io_handler.read(size)

    def write(self, data: bytes) -> None:
        """Write *data* through the selected handler."""
        self.io_handler.write(data)

    def seek(self, offset: int, whence: int = 0) -> None:
        """Reposition the selected handler."""
        self.io_handler.seek(offset, whence)

    def tell(self) -> int:
        """Return the selected handler's position."""
        return self.io_handler.tell()

    def close(self) -> None:
        """Close the selected handler."""
        self.io_handler.close()

    def getvalue(self) -> bytes:
        """Return the handler's full contents, if it offers ``getvalue``.

        Raises:
            NotImplementedError: If the handler has no ``getvalue`` method.
        """
        if not hasattr(self.io_handler, "getvalue"):
            raise NotImplementedError(
                f"{self.io_type} does not support getvalue method"
            )
        return self.io_handler.getvalue()

    def __enter__(self):
        """Enter a ``with`` block, yielding this object."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the handler when the ``with`` block ends."""
        self.close()


@KOLANICH
Copy link
Contributor

Well, I didnt touch ctypes anymore because it has no "proper" (convenient) way of handling endianess

Handling endianness is simple if we assume there are only two of them: if it is not the same as your hardware's, swap. In x86 CPUs since the i486 there is a dedicated instruction for this, BSWAP. In modern CPUs, starting from some editions of the Pentium 4, it takes 1 uop, has a latency and inverse throughput of 1, and occupies only an ALU unit. In other words: swapping endianness should be fast.

Currently im using ACCESS_COPY

I didn't mean that. I meant that by using

self.version = self.reader.read_value("u2le")

you make a copy. When we read only seq-fields and pos-instances, something like

inline uint16_t field(){
  return std::bit_cast<our_struct *>(raw_ptr_in_mmap)->field;
}

can be possible, it occupies 0 additional memory, we read the data directly from mapping on every access. When the data is BE, we just add __builtin_bswap16. This way it doesn't consume RAM except of the page frame used for mapping, all the data is on disk and is transparently read by the OS when needed, so we can work with a file larger than our physical RAM.

@Hypnootika
Copy link
Author

Hypnootika commented Oct 12, 2023

Yes, i understood later. I played around with following approach. That should do exactly what you mean:

class HeaderHandlerOptimized(MmapHandler):
    """Expose four fixed byte ranges of ``self.mmapped_data`` as accessors.

    ``mmapped_data`` is not set here; presumably the base class provides it.
    """

    def __init__(self, file_path, mode='r'):
        super().__init__(file_path, mode)
        # Byte ranges of the four header fields: 4 + 2 + 4 + 4 bytes.
        self.first_int_slice, self.second_short_slice = slice(0, 4), slice(4, 6)
        self.third_float_slice, self.fourth_float_slice = slice(6, 10), slice(10, 14)

    def get_first_int(self):
        """Return bytes 0-3 (the first integer) without decoding them."""
        span = self.first_int_slice
        return self.mmapped_data[span]

    def get_second_short(self):
        """Return bytes 4-5 (the second short) without decoding them."""
        span = self.second_short_slice
        return self.mmapped_data[span]

    def get_third_float(self):
        """Return bytes 6-9 (the third float) without decoding them."""
        span = self.third_float_slice
        return self.mmapped_data[span]

    def get_fourth_float(self):
        """Return bytes 10-13 (the fourth float) without decoding them."""
        span = self.fourth_float_slice
        return self.mmapped_data[span]

Which should also be easy to integrate

Bad thing is just that it limits the IO to mmap or memoryview (which looks insanely good after i figured out how to properly use it)

@Hypnootika
Copy link
Author

Hypnootika commented Oct 12, 2023

def measure_memoryview_handler(path='header_data_large.bin'):
    """Time loading *path* into a memoryview and decoding its first integer.

    Args:
        path: Binary file to load; must hold at least 4 bytes.  Defaults to
            the file name used by the original benchmark.

    Returns:
        Elapsed time in seconds as a float.
    """
    # perf_counter() is a monotonic, high-resolution clock meant for timing
    # short durations; time.time() can be too coarse for sub-millisecond runs.
    start = time.perf_counter()
    # Load the binary data into a memoryview
    with open(path, 'rb') as f:
        header_data = memoryview(f.read())
    # Slice and interpret the first integer (4 bytes, native unsigned int).
    # The value is unused; decoding it is part of the work being timed.
    first_int = header_data[0:4].cast('I')[0]
    return time.perf_counter() - start

# Measure the mean time per call of the memoryview method with the larger file.
# perf_counter() is monotonic and high-resolution, which suits timing a loop;
# time.time() is wall-clock time and may be too coarse.
start_time = time.perf_counter()
for _ in range(iterations):
    measure_memoryview_handler()
time_memoryview_handler_large = (time.perf_counter() - start_time) / iterations

# Bare expression: only has a visible effect where the interpreter echoes the
# value of the last expression.
time_memoryview_handler_large

The above code with 1k iterations took

0.00024369192123413087

Which is 8 times faster than everything else i tested

@Hypnootika
Copy link
Author

Well, I didnt touch ctypes anymore because it has no "proper" (convenient) way of handling endianess

Handling endianness is simple if we assume there are only two of them: if it is not the same as your hardware's, swap. In x86 CPUs since the i486 there is a dedicated instruction for this, BSWAP. In modern CPUs, starting from some editions of the Pentium 4, it takes 1 uop, has a latency and inverse throughput of 1, and occupies only an ALU unit. In other words: swapping endianness should be fast.

Currently im using ACCESS_COPY

I didn't mean that. I meant that by using

self.version = self.reader.read_value("u2le")

you make a copy. When we read only seq-fields and pos-instances, something like

inline uint16_t field(){
  return std::bit_cast<our_struct *>(raw_ptr_in_mmap)->field;
}

can be possible, it occupies 0 additional memory, we read the data directly from mapping on every access. When the data is BE, we just add __builtin_bswap16. This way it doesn't consume RAM except of the page frame used for mapping, all the data is on disk and is transparently read by the OS when needed, so we can work with a file larger than our physical RAM.

Thanks for the Info by the way, really interesting!

@KOLANICH
Copy link
Contributor

If you want to implement serialization that respects constraints, things become tricky. We cannot write directly into the mmap - the writes can violate constraints and require a recalculation. We cannot do checks before writes - some changes require atomic and coherent changes in multiple places. We cannot and shouldn't dumbly do copy-on-write - we want to minimize the amount written. In other words, objects stop being proxies to raw memory and start to have complex logic with memory allocation, deallocation, constraint checking, dependency tracking, recomputation of dependent values, SMT and other symbolic solving, moving memory areas, scheduling and maybe even bytecode. Some ideas about implementing that were described in kaitai-io/kaitai_struct#27 . Although none of it is required immediately and it can be added gradually in the indefinite future, it is important to remember that it may be needed later, and so to design things in a way that makes it possible to add it without a complete redesign.

@Hypnootika
Copy link
Author

Hypnootika commented Oct 12, 2023

Considering your answer, it looks like my tests don't really help. At least not this way.

@KOLANICH
atleast i tested your question:

would just passing a mmap object into KaitaiStream solve the issue?

Integration was a bit wonky, but all in all it's a really small change (it doesn't even interfere with the rest).
Unfortunately, I couldn't solve this yet:

class KaitaiStream:
    """Minimal stream wrapper used to inspect the type of the object passed in."""

    def __init__(self, io: mmap.mmap):
        # Debug aid: report which concrete type was handed to the stream.
        self._io = io
        io_kind = type(self._io)
        print(io_kind)
<class 'mmap.mmap'>
<class '_io.BytesIO'>
<class '_io.BytesIO'>
<class '_io.BytesIO'>
<class '_io.BytesIO'>


Pretty interesting that the remaining code doesnt give a f* at all. 

@Hypnootika
Copy link
Author

If i can fix the rest, we can remove a lot of things that KaitaiStream creates, because mmap provides it:

image

@Hypnootika
Copy link
Author

Hypnootika commented Oct 19, 2023

Another experimental approach:

import mmap
import typing
from contextlib import contextmanager
from dataclasses import dataclass, field as dc_field
from struct import Struct
from typing import List


# Should consider using __slots__ for lightweight objects
@dataclass
class Field:
    """One ``struct`` format code paired with a byte-order prefix."""

    fmt: str
    endianness: str

    @property
    def struct(self) -> Struct:
        """Build a ``Struct`` for ``endianness + fmt`` on every access."""
        layout = self.endianness + self.fmt
        return Struct(layout)

    @property
    def size(self) -> int:
        """Return the number of bytes the format occupies."""
        compiled = self.struct
        return compiled.size

    def pack(self, *args) -> bytes:
        """Encode *args* according to the format."""
        compiled = self.struct
        return compiled.pack(*args)

    def unpack(self, data: bytes) -> typing.Tuple:
        """Decode *data* according to the format and return the value tuple."""
        compiled = self.struct
        return compiled.unpack(data)


# Consider using __slots__ for lightweight callable objects
class FieldCallable:
    """Callable wrapper around a field that mirrors its attributes.

    ``fmt``, ``size`` and ``endianness`` are copied from the wrapped field
    at construction time.  ``pack`` and ``unpack`` delegate to the wrapped
    field so that instances can be used wherever a field with those methods
    is expected.
    """

    def __init__(self, _field: "Field"):
        """Store *_field* and copy its ``fmt``, ``size`` and ``endianness``."""
        self.field = _field
        self.fmt = _field.fmt
        self.size = _field.size
        self.endianness = _field.endianness

    def __call__(self):
        """Return the wrapped field."""
        return self.field

    def pack(self, *args) -> bytes:
        """Encode *args* using the wrapped field."""
        return self.field.pack(*args)

    def unpack(self, data: bytes):
        """Decode *data* using the wrapped field."""
        return self.field.unpack(data)


# Custom metaclass approach works, but might be overkill
class DataStructMeta(type):
    """Metaclass that turns ``_fields_`` entries into ``Field`` class attributes."""

    def __new__(cls, _name, bases, dct):
        # Each entry is a (name, format, endianness) triple.
        declared = dct.get('_fields_', [])
        dct.update(
            {attr: Field(code, order) for attr, code, order in declared}
        )
        return super().__new__(cls, _name, bases, dct)


@dataclass
class Structure:
    """Ordered sequence of fields that are packed and unpacked back to back.

    Each field must offer ``size``, ``pack`` and ``unpack``.
    """

    fields: List["Field"] = dc_field(default_factory=list)

    def __post_init__(self):
        # Total byte size, computed once from the fields given at construction.
        self.size = sum(f.size for f in self.fields)

    def pack(self, *args) -> bytes:
        """Encode one argument per field and concatenate the results.

        An argument may be a single value or a tuple of values, the form
        ``unpack`` returns, so ``pack(*unpack(data))`` round-trips.

        Raises:
            ValueError: If the number of arguments differs from the number
                of fields.
        """
        if len(args) != len(self.fields):
            raise ValueError(
                f"pack expected {len(self.fields)} values, got {len(args)}"
            )
        return b"".join(
            f.pack(*value) if isinstance(value, tuple) else f.pack(value)
            for f, value in zip(self.fields, args)
        )

    def unpack(self, data: bytes) -> typing.Tuple:
        """Decode *data* field by field and return a tuple of per-field tuples.

        Each field is read at its cumulative byte offset rather than at its
        index in the field list.
        """
        decoded = []
        offset = 0
        for f in self.fields:
            end = offset + f.size
            decoded.append(f.unpack(data[offset:end]))
            offset = end
        return tuple(decoded)


# Should consider using __slots__ for lightweight objects
class DataStruct(Structure):
    """Structure pre-populated with the fixed list of primitive fields.

    The order of ``_LAYOUT`` matters: module-level code unpacks
    ``DataStruct().fields`` positionally.
    """

    # (format code, byte-order prefix) pairs.
    _LAYOUT = (
        ("B", "<"),
        ("H", "<"),
        ("H", ">"),
        ("I", "<"),
        ("I", ">"),
        ("Q", "<"),
        ("Q", ">"),
        ("b", "<"),
        ("h", "<"),
        ("h", ">"),
        ("i", "<"),
        ("i", ">"),
        ("q", "<"),
        ("q", ">"),
        ("f", "<"),
        ("f", ">"),
        ("d", "<"),
        ("d", ">"),
    )

    def __init__(self, **kwargs):
        """Initialise with the built-in field list.

        Any ``fields`` keyword is replaced by the built-in list, as before.
        The list is handed to the base initialiser instead of being assigned
        afterwards, so that ``size`` is computed from it rather than left
        at 0.
        """
        kwargs["fields"] = [
            Field(code, order) for code, order in self._LAYOUT
        ]
        super().__init__(**kwargs)


# One module-level shortcut per entry of DataStruct().fields, in list order.
(
    u1, u2le, u2be, u4le, u4be, u8le, u8be,
    s1, s2le, s2be, s4le, s4be, s8le, s8be,
    f4le, f4be, f8le, f8be,
) = map(FieldCallable, DataStruct().fields)


# Should consider using __slots__ for lightweight objects
class Stream:
    """Byte-level access to the start of a file through a writable memory map.

    The map covers ``structure.size`` bytes; ``mmap`` treats a length of 0
    as "map the whole file".
    """

    def __init__(self, filename: str, structure: "Structure"):
        """Remember *filename* and *structure*; the map is created later."""
        self.filename, self.structure = filename, structure
        self.mmap: typing.Optional[mmap.mmap] = None

    def _create_mmap(self) -> None:
        """Open the file and map its first ``structure.size`` bytes.

        Raises:
            FileNotFoundError: If the file does not exist.  The original
                error is chained as the cause.
        """
        try:
            with open(self.filename, "r+b") as _f:
                self.mmap = mmap.mmap(
                    _f.fileno(), self.structure.size, access=mmap.ACCESS_WRITE
                )
        except FileNotFoundError as err:
            raise FileNotFoundError(f"File {self.filename} not found") from err

    @classmethod
    @contextmanager
    def from_structure(cls, filename: str, structure: "Structure"):
        """Context manager yielding a mapped ``Stream``; unmaps on exit."""
        _stream = cls(filename, structure)
        _stream._create_mmap()
        try:
            yield _stream
        finally:
            _stream.mmap.close()

    def __getitem__(self, key: typing.Union[int, slice]) -> typing.Union[bytes, bytearray]:
        """Return the bytes selected by an int index or a slice.

        Raises:
            TypeError: If *key* is neither an int nor a slice.
        """
        if isinstance(key, int):
            # For key == -1 the stop would be 0 and select nothing; an open
            # stop selects the last byte instead.
            return self.mmap[key:key + 1 or None]
        if isinstance(key, slice):
            return self.mmap[key]
        raise TypeError("Invalid argument type")

    def __setitem__(self, key: typing.Union[int, slice], value: typing.Union[bytes, bytearray]):
        """Overwrite the bytes selected by an int index or a slice.

        Raises:
            TypeError: If *key* is neither an int nor a slice.
        """
        if isinstance(key, int):
            # Same open-stop handling for key == -1 as in __getitem__.
            self.mmap[key:key + 1 or None] = value
        elif isinstance(key, slice):
            self.mmap[key] = value
        else:
            raise TypeError("Invalid argument type")

    @property
    def size(self) -> int:
        """Return the byte size of the associated structure."""
        return self.structure.size

    def read(self, offset: int, size: int) -> bytes:
        """Return *size* bytes starting at *offset*."""
        return self.mmap[offset:offset + size]

    def write(self, offset: int, data: bytes):
        """Overwrite ``len(data)`` bytes starting at *offset*."""
        self.mmap[offset:offset + len(data)] = data

    def read_struct(self, offset: int, structure: "Structure") -> typing.Tuple:
        """Read ``structure.size`` bytes at *offset* and decode them."""
        return structure.unpack(self.read(offset, structure.size))

    def write_struct(self, offset: int, structure: "Structure", data: typing.Tuple):
        """Encode *data* with *structure* and write it at *offset*."""
        self.write(offset, structure.pack(*data))
    # Consider using __slots__


if __name__ == "__main__":
    # Demo: map the start of the file and print every field's raw bytes.
    file = "test.anft"

    primitives = [u1, u2le, u2be, u4le, u4be, u8le, u8be, s1, s2le, s2be,
                  s4le, s4be, s8le, s8be, f4le, f4be, f8le, f8be]
    # The full set of primitive fields, laid out twice in a row.
    struct = Structure(fields=primitives * 2)
    with Stream.from_structure(file, struct) as stream:
        # Track the running byte offset: a field starts where the previous
        # one ended, not at its index in the field list.
        offset = 0
        for field in struct.fields:
            print(f"{field.fmt} {field.endianness} {field.size}, {stream.size}, {stream.mmap.size()}, {stream.structure.size}")
            print(stream.read(offset, field.size))
            offset += field.size

@KOLANICH, this is nearly what you wanted. The mmap itself is only as big in memory as the structured data given.

Next step would be avoiding to mmap and just to collect a dict or something of size + offset, for an abstract IO

@Hypnootika
Copy link
Author

I actually have a working memoryview KaitaiStruct embracing zerocopy and its init and io is 40% faster than the current implementation BUT:

Achieving this needs a nearly complete rewrite of the Python runtime and of the generated classes (the latter is less important; the gain there was only about 10%).

Long story short, case closed for me

@Hypnootika
Copy link
Author

Hypnootika commented Oct 23, 2023

Edit: forget about it, its nonsense

So this topic didn't let me be happy. (It became a personal topic, dont worry, i made this for me :D)

Good News:

  • Zero-copy streams with the same functionality as the current ones are easy to implement. The only problem is that as soon as an instance reads, memory usage obviously goes up.

So i tried a lot of methods to create some kind of "offset-size" mapping.

The idea would be, that (in this example):
self.index_map = (ctypes.c_uint64 * (self.size // 32))()
With some changes, the above could work as follows: when an instance's _read() is called, instead of copying the data by reading it, we only add its mapping to the array. If someone then actually wants to print the data (or do anything else with it), we traverse the array and create views on that piece of data.

This way it should be possible to avoid most memory alloc.

Would love some feedback (not in terms of implementing it into Kaitai but rather what you think about the method)

class MIO_IO(Cursor):
    """Wrap a ``Cursor`` and build an index of chunk start offsets."""

    def __init__(self, io_stream: Cursor) -> None:
        """Store *io_stream* and size the index from the stream itself.

        The size is measured on the wrapped stream instead of being read
        from a hard-coded file path, so any input works.
        """
        self._io: Cursor = io_stream
        # NOTE(review): assumes Cursor.seek accepts whence=2 (seek from end)
        # like file objects do; only whence=1 is exercised elsewhere here.
        origin = io_stream.tell()
        io_stream.seek(0, 2)
        self.size = io_stream.tell()
        io_stream.seek(origin, 0)
        # One 64-bit slot per 32-byte chunk.
        self.index_map = (ctypes.c_uint64 * (self.size // 32))()

    @classmethod
    def from_file(cls, file: str) -> "MIO_IO":
        """Map *file* read-only and wrap the map in a new instance."""
        with open(file, 'rb') as f:
            m: Cursor = Cursor(mmap.mmap(f.fileno(), 0, access = mmap.ACCESS_READ))
        return cls(m)

    def create_index(self, chunk_size: int):
        """Record the stream offset of every full *chunk_size*-byte chunk.

        Replaces ``index_map`` with one slot per chunk and advances the
        wrapped stream by one chunk per slot.

        Raises:
            ValueError: If *chunk_size* is not positive.
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer")
        chunks = self.size // chunk_size
        self.index_map = (ctypes.c_uint64 * chunks)()
        for i in range(chunks):
            self.index_map[i] = self._io.tell()
            self._io.seek(chunk_size, 1)


def main():
    """Map the test file and index it in 32-byte chunks."""
    indexer = MIO_IO.from_file(r"D:\Dev\MIOStream\MIOStream\files\test.dat")
    indexer.create_index(32)

image

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    41     24.0 MiB     24.0 MiB           1   def main():
    42     24.0 MiB      0.0 MiB           1       file_path = r"D:\Dev\MIOStream\MIOStream\files\test.dat"
    43    184.1 MiB    160.1 MiB           1       M = MIO_IO.from_file(file_path)
    44    184.1 MiB      0.0 MiB           1       M.create_index(32)


Mon Oct 23 20:35:42 2023    stats

         41952050 function calls in 6.750 seconds

   Ordered by: cumulative time
   List reduced from 14 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    6.750    6.750 {built-in method builtins.exec}
        1    0.004    0.004    6.750    6.750 <string>:1(<module>)
        1    0.000    0.000    6.746    6.746 D:\Dev\MIOStream\MIOStream\mio.py:41(main)
        1    4.571    4.571    6.737    6.737 D:\Dev\MIOStream\MIOStream\mio.py:25(create_index)
 20976019    1.408    0.000    1.408    0.000 {method 'seek' of 'iocursor.cursor.Cursor' objects}
 20976019    0.758    0.000    0.758    0.000 {method 'tell' of 'iocursor.cursor.Cursor' objects}
        1    0.000    0.000    0.010    0.010 D:\Dev\MIOStream\MIOStream\mio.py:19(from_file)
        1    0.009    0.009    0.010    0.010 D:\Dev\MIOStream\MIOStream\mio.py:14(__init__)
        1    0.000    0.000    0.000    0.000 {built-in method io.open}
        1    0.000    0.000    0.000    0.000 <frozen genericpath>:48(getsize)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants