Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Static msg types using msgspec #311

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions tractor/_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,18 @@ async def started(
await self.chan.send({'started': value, 'cid': self.cid})
self._started_called = True


# TODO: msg capability context api1
# @acm
# async def enable_msg_caps(
# self,
# msg_subtypes: Union[
# list[list[Struct]],
# Protocol, # hypothetical type that wraps a msg set
# ],
# ) -> tuple[Callable, Callable]: # payload enc, dec pair
# ...

# TODO: do we need a restart api?
# async def restart(self) -> None:
# pass
Expand Down
159 changes: 158 additions & 1 deletion tractor/msg.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,19 @@
# - https://jcristharif.com/msgspec/api.html#struct
# - https://jcristharif.com/msgspec/extending.html
# via ``msgpack-python``:
# - https://github.com/msgpack/msgpack-python#packingunpacking-of-custom-data-type
# https://github.com/msgpack/msgpack-python#packingunpacking-of-custom-data-type

from __future__ import annotations
from contextlib import contextmanager as cm
from pkgutil import resolve_name
from typing import Union, Any, Optional


from msgspec import Struct, Raw
from msgspec.msgpack import (
Encoder,
Decoder,
)


class NamespacePath(str):
Expand Down Expand Up @@ -78,3 +87,151 @@ def from_ref(
(ref.__module__,
getattr(ref, '__name__', ''))
))


# LIFO codec stack that is appended when the user opens the
# ``configure_native_msgs()`` cm below to configure a new codec set
# which will be applied to all new (msgspec relevant) IPC transports
# that are spawned **after** the configure call is made.
_lifo_codecs: list[
tuple[
Encoder,
Decoder,
],
] = [(Encoder(), Decoder())]


def get_msg_codecs() -> tuple[
Encoder,
Decoder,
]:
'''
Return the currently configured ``msgspec`` codec set.

The defaults are defined above.

'''
global _lifo_codecs
return _lifo_codecs[-1]


@cm
def configure_native_msgs(
tagged_structs: list[Struct],
):
'''
Push a codec set that will natively decode
tagged structs provied in ``tagged_structs``
in all IPC transports and pop the codec on exit.

'''
# See "tagged unions" docs:
# https://jcristharif.com/msgspec/structs.html#tagged-unions

# "The quickest way to enable tagged unions is to set tag=True when
# defining every struct type in the union. In this case tag_field
# defaults to "type", and tag defaults to the struct class name
# (e.g. "Get")."
enc = Encoder()

types_union = Union[tagged_structs[0]] | Any
for struct in tagged_structs[1:]:
types_union |= struct

dec = Decoder(types_union)

_lifo_codecs.append((enc, dec))
try:
print("YOYOYOOYOYOYOY")
yield enc, dec
finally:
print("NONONONONON")
_lifo_codecs.pop()


class Header(Struct, tag=True):
'''
A msg header which defines payload properties

'''
uid: str
msgtype: Optional[str] = None


class Msg(Struct, tag=True):
'''
The "god" msg type, a box for task level msg types.

'''
header: Header
payload: Raw


_root_dec = Decoder(Msg)
_root_enc = Encoder()

# sub-decoders for retreiving embedded
# payload data and decoding to a sender
# side defined (struct) type.
_subdecs: dict[
Optional[str],
Decoder] = {
None: Decoder(Any),
}


@cm
def enable_context(
msg_subtypes: list[list[Struct]]
) -> Decoder:

for types in msg_subtypes:
first = types[0]

# register using the default tag_field of "type"
# which seems to map to the class "name".
tags = [first.__name__]

# create a tagged union decoder for this type set
type_union = Union[first]
for typ in types[1:]:
type_union |= typ
tags.append(typ.__name__)

dec = Decoder(type_union)

# register all tags for this union sub-decoder
for tag in tags:
_subdecs[tag] = dec
try:
yield dec
finally:
for tag in tags:
_subdecs.pop(tag)


def decmsg(msg: Msg) -> Any:
msg = _root_dec.decode(msg)
tag_field = msg.header.msgtype
dec = _subdecs[tag_field]
return dec.decode(msg.payload)


def encmsg(
dialog_id: str | int,
payload: Any,
) -> Msg:

tag_field = None

plbytes = _root_enc.encode(payload)
if b'type' in plbytes:
assert isinstance(payload, Struct)
tag_field = type(payload).__name__
payload = Raw(plbytes)

msg = Msg(
Header(dialog_id, tag_field),
payload,
)
return _root_enc.encode(msg)