diff --git a/shared_memory_dict/dict.py b/shared_memory_dict/dict.py index ba01f95..1eb6202 100644 --- a/shared_memory_dict/dict.py +++ b/shared_memory_dict/dict.py @@ -1,4 +1,5 @@ import logging +import os import sys import warnings from contextlib import contextmanager @@ -159,11 +160,29 @@ def _get_or_create_memory_block( try: return SharedMemory(name=name) except FileNotFoundError: + self.check_security(name) shm = SharedMemory(name=name, create=True, size=size) data = self._serializer.dumps({}) shm.buf[: len(data)] = data return shm + def check_security(self, name: str) -> None: + """Check if shared memory belongs to the current user and is only read+writeable for us""" + if os.name == 'nt': + return + + if '/' in name: + raise TypeError('Name must not contain "/".') + + shm_file = os.path.join('/dev/shm', name) + stat = os.stat(shm_file) + if ( + stat.st_uid != os.getuid() + or stat.st_gid != os.getgid() + or stat.st_mode != 0o100600 + ): + os.unlink(shm_file) + def _save_memory(self, db: Dict[str, Any]) -> None: data = self._serializer.dumps(db) try: