From 0132220960804f99dae20156c675d77c1728ae7b Mon Sep 17 00:00:00 2001 From: Michael Vogt Date: Wed, 7 Aug 2024 16:49:30 +0200 Subject: [PATCH] jsoncomm: transparently handle huge messages And drop the workarounds from #824,#1331,#1836 --- osbuild/inputs.py | 18 +++--------------- osbuild/sources.py | 21 +++------------------ osbuild/util/jsoncomm.py | 19 +++++++++++++------ 3 files changed, 19 insertions(+), 39 deletions(-) diff --git a/osbuild/inputs.py b/osbuild/inputs.py index 229c12641..9d37b1e18 100644 --- a/osbuild/inputs.py +++ b/osbuild/inputs.py @@ -88,10 +88,8 @@ def map(self, ip: Input, store: ObjectStore) -> Tuple[str, Dict]: } } - with make_args_file(store.tmp, args) as fd: - fds = [fd] - client = self.service_manager.start(f"input/{ip.name}", ip.info.path) - reply, _ = client.call_with_fds("map", {}, fds) + client = self.service_manager.start(f"input/{ip.name}", ip.info.path) + reply = client.call("map", args) path = reply["path"] @@ -105,14 +103,6 @@ def map(self, ip: Input, store: ObjectStore) -> Tuple[str, Dict]: return reply -@contextlib.contextmanager -def make_args_file(tmp, args): - with tempfile.TemporaryFile("w+", dir=tmp, encoding="utf-8") as f: - json.dump(args, f) - f.seek(0) - yield f.fileno() - - class InputService(host.Service): """Input host service""" @@ -126,10 +116,8 @@ def unmap(self): def stop(self): self.unmap() - def dispatch(self, method: str, _, _fds): + def dispatch(self, method: str, args, fds): if method == "map": - with os.fdopen(_fds.steal(0)) as f: - args = json.load(f) store = StoreClient(connect_to=args["api"]["store"]) r = self.map(store, args["origin"], diff --git a/osbuild/sources.py b/osbuild/sources.py index a3f532e42..5dfb8f58c 100644 --- a/osbuild/sources.py +++ b/osbuild/sources.py @@ -30,6 +30,7 @@ def download(self, mgr: host.ServiceManager, store: ObjectStore, libdir: PathLik cache = os.path.join(store.store, "sources") args = { + "items": self.items, "options": self.options, "cache": cache, "output": None, @@ -38,20 +39,10 @@ def download(self, mgr: host.ServiceManager, store: ObjectStore, libdir: PathLik } client = mgr.start(f"source/{source}", self.info.path) - - with self.make_items_file(store.tmp) as fd: - fds = [fd] - reply = client.call_with_fds("download", args, fds) + reply = client.call("download", args) return reply - @contextlib.contextmanager - def make_items_file(self, tmp): - with tempfile.TemporaryFile("w+", dir=tmp, encoding="utf-8") as f: - json.dump(self.items, f) - f.seek(0) - yield f.fileno() - # "name", "id", "stages", "results" is only here to make it looks like a # pipeline for the monitor. This should be revisited at some point # and maybe the monitor should get first-class support for @@ -105,12 +96,6 @@ def exists(self, checksum, _desc) -> bool: """Returns True if the item to download is in cache. """ return os.path.isfile(f"{self.cache}/{checksum}") - @staticmethod - def load_items(fds): - with os.fdopen(fds.steal(0)) as f: - items = json.load(f) - return items - def setup(self, args): self.cache = os.path.join(args["cache"], self.content_type) os.makedirs(self.cache, exist_ok=True) @@ -120,7 +105,7 @@ def dispatch(self, method: str, args, fds): if method == "download": self.setup(args) with tempfile.TemporaryDirectory(prefix=".unverified-", dir=self.cache) as self.tmpdir: - self.fetch_all(SourceService.load_items(fds)) + self.fetch_all(args["items"]) return None, None raise host.ProtocolError("Unknown method") diff --git a/osbuild/util/jsoncomm.py b/osbuild/util/jsoncomm.py index e31a49ca7..d89ca4335 100644 --- a/osbuild/util/jsoncomm.py +++ b/osbuild/util/jsoncomm.py @@ -353,7 +353,8 @@ def recv(self): if level == socket.SOL_SOCKET and ty == socket.SCM_RIGHTS: assert len(data) % fds.itemsize == 0 fds.frombytes(data) - fdset = FdSet(rawfds=fds) + fd_payload = fds[0] + fdset = FdSet(rawfds=fds[1:]) # Check the returned message flags. If the message was truncated, we # have to discard it. This shouldn't happen, but there is no harm in @@ -364,7 +365,8 @@ def recv(self): raise BufferError try: - payload = json.loads(msg[0]) + with os.fdopen(fd_payload) as f: + payload = json.loads(f.read()) except json.JSONDecodeError as e: raise BufferError from e @@ -399,13 +401,18 @@ def send(self, payload: object, *, fds: Optional[list] = None): if not self._socket: raise RuntimeError("Tried to send without socket.") - serialized = json.dumps(payload).encode() cmsg = [] + serialized = json.dumps(payload) + fd_payload = os.memfd_create("jsoncomm/payload", 0) + with os.fdopen(os.dup(fd_payload), "w") as f: + f.write(serialized) + f.seek(0) + all_fds = [fd_payload] if fds: - cmsg.append((socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds))) + all_fds += fds + cmsg.append((socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", all_fds))) - n = self._socket.sendmsg([serialized], cmsg, 0) - assert n == len(serialized) + self._socket.sendmsg([b"{}"], cmsg, 0) def send_and_recv(self, payload: object, *, fds: Optional[list] = None): """Send a message and wait for a reply