diff --git a/doc/source/index.rst b/doc/source/index.rst index d37d8c3..cf4b3c2 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -59,6 +59,8 @@ API :members: .. autoclass:: LRU :members: +.. autoclass:: SharedMemory + :members: .. autoclass:: Sieve :members: .. autoclass:: Zip diff --git a/zict/__init__.py b/zict/__init__.py index fcdc44b..1d79ec2 100644 --- a/zict/__init__.py +++ b/zict/__init__.py @@ -6,6 +6,7 @@ from zict.func import Func as Func from zict.lmdb import LMDB as LMDB from zict.lru import LRU as LRU +from zict.shared_memory import SharedMemory as SharedMemory from zict.sieve import Sieve as Sieve from zict.utils import InsertionSortedSet as InsertionSortedSet from zict.zip import Zip as Zip diff --git a/zict/shared_memory/__init__.py b/zict/shared_memory/__init__.py new file mode 100644 index 0000000..7d03b6e --- /dev/null +++ b/zict/shared_memory/__init__.py @@ -0,0 +1 @@ +from zict.shared_memory.shared_memory import SharedMemory diff --git a/zict/shared_memory/_linux.py b/zict/shared_memory/_linux.py new file mode 100644 index 0000000..29b4394 --- /dev/null +++ b/zict/shared_memory/_linux.py @@ -0,0 +1,62 @@ +"""Linux implementation of :class:`zict.SharedMemory`. + +Wraps around glibc ``memfd_create``. +""" +from __future__ import annotations + +import ctypes +import mmap +import os +from collections.abc import Iterable + +_memfd_create = None + + +def _setitem(safe_key: str, value: Iterable[bytes | bytearray | memoryview]) -> int: + global _memfd_create + if _memfd_create is None: + libc = ctypes.CDLL("libc.so.6") + _memfd_create = libc.memfd_create + + fd = _memfd_create(safe_key.encode("ascii"), 0) + if fd == -1: + raise OSError("Call to memfd_create failed") # pragma: nocover + + with os.fdopen(fd, "wb", closefd=False) as fh: + fh.writelines(value) + + return fd + + +def _getitem(fd: int) -> memoryview: + # This opens a second fd for as long as the memory map is referenced. + # Sadly there does not seem a way to extract the fd from the mmap, so we have to + # keep the original fd open for the purpose of exporting. + return memoryview(mmap.mmap(fd, 0)) + + +def _delitem(fd: int) -> None: + # Close the original fd. There may be other fd's still open if the shared memory is + # referenced somewhere else. + # This is also called by SharedMemory.__del__. + os.close(fd) + + +def _export(safe_key: str, fd: int) -> tuple: + return safe_key, os.getpid(), fd + + +def _import(safe_key: str, pid: int, fd: int) -> int: + # if fd has been closed, raise FileNotFoundError + # if fd has been closed and reopened to something else, this may also raise a + # generic OSError, e.g. if this is now a socket + new_fd = os.open(f"/proc/{pid}/fd/{fd}", os.O_RDWR) + + expect = f"/memfd:{safe_key} (deleted)" + actual = os.readlink(f"/proc/{os.getpid()}/fd/{new_fd}") + if actual != expect: + # fd has been closed and reopened to something else + os.close(new_fd) + raise OSError() + + return new_fd diff --git a/zict/shared_memory/_windows.py b/zict/shared_memory/_windows.py new file mode 100644 index 0000000..57e4181 --- /dev/null +++ b/zict/shared_memory/_windows.py @@ -0,0 +1,51 @@ +"""Windows implementation of :class:`zict.SharedMemory`. + +Conveniently, :class:`multiprocessing.shared_memory.SharedMemory` already wraps around +the Windows API we want to use, so this is implemented as a hack on top of it. +""" +from __future__ import annotations + +import mmap +import multiprocessing.shared_memory +from collections.abc import Collection +from typing import cast + + +class _PySharedMemoryNoClose(multiprocessing.shared_memory.SharedMemory): + def __del__(self) -> None: + pass + + +def _setitem( + safe_key: str, value: Collection[bytes | bytearray | memoryview] +) -> memoryview: + nbytes = sum(v.nbytes if isinstance(v, memoryview) else len(v) for v in value) + shm = _PySharedMemoryNoClose(safe_key, create=True, size=nbytes) + mm = cast(mmap.mmap, shm.buf.obj) + for v in value: + mm.write(v) + # This dereferences shm; if we hadn't overridden the __del__ method, it would cause + # it to automatically close the memory map and deallocate the shared memory. + return shm.buf + + +def _getitem(mm: memoryview) -> memoryview: + # Nothing to do. This is just for compatibility with the Linux implementation, which + # instead creates a memory map on the fly. + return mm + + +def _delitem(mm: memoryview) -> None: + # Nothing to do. The shared memory is released as soon as the last memory map + # referencing it is destroyed. + pass + + +def _export(safe_key: str, mm: memoryview) -> tuple: + return (safe_key,) + + +def _import(safe_key: str) -> memoryview: + # Raises OSError in case of invalid key + shm = _PySharedMemoryNoClose(safe_key) + return shm.buf diff --git a/zict/shared_memory/shared_memory.py b/zict/shared_memory/shared_memory.py new file mode 100644 index 0000000..4e61d7b --- /dev/null +++ b/zict/shared_memory/shared_memory.py @@ -0,0 +1,248 @@ +from __future__ import annotations + +import secrets +import sys +from collections.abc import Iterator, KeysView +from typing import Any +from urllib.parse import quote, unquote + +from zict.common import ZictBase + +if sys.platform == "linux": + from zict.shared_memory._linux import _delitem, _export, _getitem, _import, _setitem +elif sys.platform == "win32": + from zict.shared_memory._windows import ( + _delitem, + _export, + _getitem, + _import, + _setitem, + ) + + +class SharedMemory(ZictBase[str, memoryview]): + """Mutable Mapping interface to shared memory. + + **Supported OSs:** Linux, Windows + + Keys must be strings, values must be buffers. + Keys are stored in private memory, and other SharedMemory objects by default won't + see them - even in case of key collision, the two pieces of data remain separate. + + In order to share the same buffer, one SharedMemory object must call + :meth:`export` and the other :meth:`import_`. + + **Resources usage** + + On Linux, you will hold 1 file descriptor open for every key in the SharedMemory + mapping, plus 1 file descriptor for every returned memoryview that is referenced + somewhere else. Please ensure that your ``ulimit`` is high enough to cope with this. + + If you expect to call ``__getitem__`` multiple times on the same key while the + return value from the previous call is still in use, you should wrap this mapping in + a :class:`~zict.Cache`: + + >>> import zict + >>> shm = zict.Cache( + ... zict.SharedMemory(), + ... zict.WeakValueMapping(), + ... update_on_set=False, + ... ) # doctest: +SKIP + + The above will cap the amount of open file descriptors per key to 2. + + **Lifecycle** + + Memory is released when all the SharedMemory objects that were sharing the key have + deleted it *and* the buffer returned by ``__getitem__`` is no longer referenced + anywhere else. + Process termination, including ungraceful termination (SIGKILL, SIGSEGV), also + releases the memory; in other words you don't risk leaking memory to the + OS if all processes that were sharing it crash or are killed. + + Examples + -------- + In process 1: + + >>> import pickle, numpy, zict # doctest: +SKIP + >>> shm = zict.SharedMemory() # doctest: +SKIP + >>> a = numpy.random.random(2**27) # 1 GiB # doctest: +SKIP + >>> buffers = [] # doctest: +SKIP + >>> pik = pickle.dumps(a, protocol=5, buffer_callback=buffers.append) + ... # doctest: +SKIP + >>> # This deep-copies the buffer, resulting in 1 GiB private + 1 GiB shared memory. + >>> shm["a"] = buffers # doctest: +SKIP + >>> # Release private memory, leaving only the shared memory allocated + >>> del a, buffers # doctest: +SKIP + >>> # Recreate array from shared memory. This requires no extra memory. + >>> a = pickle.loads(pik, buffers=[shm["a"]]) # doctest: +SKIP + >>> # Send trivially-sized metadata (<1 kiB) to the peer process somehow. + >>> send_to_process_2((pik, shm.export("a"))) # doctest: +SKIP + + In process 2: + + >>> import pickle, zict # doctest: +SKIP + >>> shm = zict.SharedMemory() # doctest: +SKIP + >>> pik, metadata = receive_from_process_1() # doctest: +SKIP + >>> key = shm.import_(metadata) # returns "a" # doctest: +SKIP + >>> a = pickle.loads(pik, buffers=[shm[key]]) # doctest: +SKIP + + Now process 1 and 2 hold a reference to the same memory; in-place changes on one + process are reflected onto the other. The shared memory is released after you delete + the key and dereference the buffer returned by ``__getitem__`` on *both* processes: + + >>> del shm["a"] # doctest: +SKIP + >>> del a # doctest: +SKIP + + or alternatively when both processes are terminated. + + **Implementation notes** + + This mapping uses OS-specific shared memory, which + + 1. can be shared among already existing processes, e.g. unlike ``mmap(fd=-1)``, and + 2. is automatically cleaned up by the OS in case of ungraceful process termination, + e.g. unlike ``shm_open`` (which is used by :mod:`multiprocessing.shared_memory` + on all POSIX OS'es) + + It is implemented on top of ``memfd_create`` on Linux and ``CreateFileMapping`` on + Windows. Notably, there is no POSIX equivalent for these API calls, as it only + implements ``shm_open`` which would inevitably cause memory leaks in case of + ungraceful process termination. + """ + + # {key: (unique safe key, implementation-specific data)} + _data: dict[str, tuple[str, Any]] + + def __init__(self): # type: ignore[no-untyped-def] + if sys.platform not in ("linux", "win32"): + raise NotImplementedError( + "SharedMemory is only available on Linux and Windows" + ) + + self._data = {} + + def __str__(self) -> str: + return f"" + + __repr__ = __str__ + + def __setitem__( + self, + key: str, + value: bytes + | bytearray + | memoryview + | list[bytes | bytearray | memoryview] + | tuple[bytes | bytearray | memoryview, ...], + ) -> None: + try: + del self[key] + except KeyError: + pass + + if not isinstance(value, (tuple, list)): + value = [value] + safe_key = quote(key, safe="") + "#" + secrets.token_bytes(8).hex() + impl_data = _setitem(safe_key, value) + self._data[key] = safe_key, impl_data + + def __getitem__(self, key: str) -> memoryview: + _, impl_data = self._data[key] + return _getitem(impl_data) + + def __delitem__(self, key: str) -> None: + _, impl_data = self._data.pop(key) + _delitem(impl_data) + + def __del__(self) -> None: + try: + data_values = self._data.values() + except Exception: + # Interpreter shutdown + return # pragma: nocover + + for _, impl_data in data_values: + try: + _delitem(impl_data) + except Exception: + pass # pragma: nocover + + def close(self) -> None: + # Implements ZictBase.close(). Also triggered by __exit__. + self.clear() + + def __contains__(self, key: object) -> bool: + return key in self._data + + def keys(self) -> KeysView[str]: + return self._data.keys() + + def __iter__(self) -> Iterator[str]: + return iter(self._data) + + def __len__(self) -> int: + return len(self._data) + + def export(self, key: str) -> tuple: + """Export metadata for a key, which can be fed into :meth:`import_` on + another process. + + Returns + ------- + Opaque metadata object (implementation-specific) to be passed to + :meth:`import_`. It is serializable with JSON, YAML, and msgpack. + + See Also + -------- + import_ + """ + return _export(*self._data[key]) + + def import_(self, metadata: tuple | list) -> str: + """Import a key from another process, starting to share the memory area. + + You should treat parameters as implementation details and just unpack the tuple + that was generated by :meth:`export`. + + Returns + ------- + Key that was just added to the mapping + + Raises + ------ + FileNotFoundError + Either the key or the whole SharedMemory object were deleted on the process + where you ran :meth:`export`, or the process was terminated. + + Notes + ----- + On Windows, this method will raise FileNotFoundError if the key has been deleted + from the other SharedMemory mapping *and* it is no longer referenced anywhere. + On Linux, this method will raise as soon as the key is deleted from the other + SharedMemory mapping, even if it's still referenced. + + e.g. this code is not portable, as it will work on Windows but not on Linux: + + >>> buf = shm["x"] = buf # doctest: +SKIP + >>> meta = shm.export("x") # doctest: +SKIP + >>> del shm["x"] # doctest: +SKIP + + See Also + -------- + export + """ + safe_key = metadata[0] + key = unquote(safe_key.split("#")[0]) + + try: + del self[key] + except KeyError: + pass + + try: + impl_data = _import(*metadata) + except OSError: + raise FileNotFoundError(f"Peer process no longer holds the key: {key!r}") + self._data[key] = safe_key, impl_data + return key