Skip to content

Commit

Permalink
jsoncomm: transparently handle huge messages
Browse files Browse the repository at this point in the history
And drop the workarounds from osbuild#824,osbuild#1331,osbuild#1836
  • Loading branch information
mvo5 committed Aug 7, 2024
1 parent ae72480 commit 0132220
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 39 deletions.
18 changes: 3 additions & 15 deletions osbuild/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,8 @@ def map(self, ip: Input, store: ObjectStore) -> Tuple[str, Dict]:
}
}

with make_args_file(store.tmp, args) as fd:
fds = [fd]
client = self.service_manager.start(f"input/{ip.name}", ip.info.path)
reply, _ = client.call_with_fds("map", {}, fds)
client = self.service_manager.start(f"input/{ip.name}", ip.info.path)
reply = client.call("map", args)

path = reply["path"]

Expand All @@ -105,14 +103,6 @@ def map(self, ip: Input, store: ObjectStore) -> Tuple[str, Dict]:
return reply


@contextlib.contextmanager
def make_args_file(tmp, args):
with tempfile.TemporaryFile("w+", dir=tmp, encoding="utf-8") as f:
json.dump(args, f)
f.seek(0)
yield f.fileno()


class InputService(host.Service):
"""Input host service"""

Expand All @@ -126,10 +116,8 @@ def unmap(self):
def stop(self):
self.unmap()

def dispatch(self, method: str, _, _fds):
def dispatch(self, method: str, args, fds):
if method == "map":
with os.fdopen(_fds.steal(0)) as f:
args = json.load(f)
store = StoreClient(connect_to=args["api"]["store"])
r = self.map(store,
args["origin"],
Expand Down
21 changes: 3 additions & 18 deletions osbuild/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def download(self, mgr: host.ServiceManager, store: ObjectStore, libdir: PathLik
cache = os.path.join(store.store, "sources")

args = {
"items": self.items,
"options": self.options,
"cache": cache,
"output": None,
Expand All @@ -38,20 +39,10 @@ def download(self, mgr: host.ServiceManager, store: ObjectStore, libdir: PathLik
}

client = mgr.start(f"source/{source}", self.info.path)

with self.make_items_file(store.tmp) as fd:
fds = [fd]
reply = client.call_with_fds("download", args, fds)
reply = client.call("download", args)

return reply

@contextlib.contextmanager
def make_items_file(self, tmp):
with tempfile.TemporaryFile("w+", dir=tmp, encoding="utf-8") as f:
json.dump(self.items, f)
f.seek(0)
yield f.fileno()

# "name", "id", "stages", "results" is only here to make it looks like a
# pipeline for the monitor. This should be revisited at some point
# and maybe the monitor should get first-class support for
Expand Down Expand Up @@ -105,12 +96,6 @@ def exists(self, checksum, _desc) -> bool:
"""Returns True if the item to download is in cache. """
return os.path.isfile(f"{self.cache}/{checksum}")

@staticmethod
def load_items(fds):
with os.fdopen(fds.steal(0)) as f:
items = json.load(f)
return items

def setup(self, args):
self.cache = os.path.join(args["cache"], self.content_type)
os.makedirs(self.cache, exist_ok=True)
Expand All @@ -120,7 +105,7 @@ def dispatch(self, method: str, args, fds):
if method == "download":
self.setup(args)
with tempfile.TemporaryDirectory(prefix=".unverified-", dir=self.cache) as self.tmpdir:
self.fetch_all(SourceService.load_items(fds))
self.fetch_all(args["items"])
return None, None

raise host.ProtocolError("Unknown method")
19 changes: 13 additions & 6 deletions osbuild/util/jsoncomm.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,8 @@ def recv(self):
if level == socket.SOL_SOCKET and ty == socket.SCM_RIGHTS:
assert len(data) % fds.itemsize == 0
fds.frombytes(data)
fdset = FdSet(rawfds=fds)
fd_payload = fds[0]
fdset = FdSet(rawfds=fds[1:])

# Check the returned message flags. If the message was truncated, we
# have to discard it. This shouldn't happen, but there is no harm in
Expand All @@ -364,7 +365,8 @@ def recv(self):
raise BufferError

try:
payload = json.loads(msg[0])
with os.fdopen(fd_payload) as f:
payload = json.loads(f.read())
except json.JSONDecodeError as e:
raise BufferError from e

Expand Down Expand Up @@ -399,13 +401,18 @@ def send(self, payload: object, *, fds: Optional[list] = None):
if not self._socket:
raise RuntimeError("Tried to send without socket.")

serialized = json.dumps(payload).encode()
cmsg = []
serialized = json.dumps(payload)
fd_payload = os.memfd_create("jsoncomm/payload", 0)
with os.fdopen(os.dup(fd_payload), "w") as f:
f.write(serialized)
f.seek(0)
all_fds = [fd_payload]
if fds:
cmsg.append((socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds)))
all_fds += fds
cmsg.append((socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", all_fds)))

n = self._socket.sendmsg([serialized], cmsg, 0)
assert n == len(serialized)
self._socket.sendmsg([b"{}"], cmsg, 0)

def send_and_recv(self, payload: object, *, fds: Optional[list] = None):
"""Send a message and wait for a reply
Expand Down

0 comments on commit 0132220

Please sign in to comment.