From ff47002b94dcf7565079e3c351ca94fb5432184b Mon Sep 17 00:00:00 2001 From: "J.P. Hutchins" Date: Tue, 23 Jul 2024 19:43:16 -0700 Subject: [PATCH] fix: add smp_date field used to set the bytes if deserializing This is not ideal because it now exposes the field as a confusing argument during serialization. Generally, de/serialization is getting messy with some redundant operations. --- smp/message.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/smp/message.py b/smp/message.py index 6eca639..2084503 100644 --- a/smp/message.py +++ b/smp/message.py @@ -40,13 +40,14 @@ class _MessageBase(ABC, BaseModel): header: smpheader.Header = None # type: ignore version: smpheader.Version = smpheader.Version.V2 sequence: int = None # type: ignore + smp_data: bytes = None # type: ignore def __bytes__(self) -> bytes: - return self._bytes + return self.smp_data @property def BYTES(self) -> bytes: - return self._bytes + return self.smp_data @classmethod def loads(cls: Type[T], data: bytes) -> T: @@ -54,6 +55,7 @@ def loads(cls: Type[T], data: bytes) -> T: message = cls( header=smpheader.Header.loads(data[: smpheader.Header.SIZE]), **cast(dict, cbor2.loads(data[smpheader.Header.SIZE :])), + smp_data=data, ) if message.header is None: # pragma: no cover raise ValueError @@ -75,10 +77,11 @@ def load(cls: Type[T], header: smpheader.Header, data: dict) -> T: def model_post_init(self, _: None) -> None: data_bytes = cbor2.dumps( self.model_dump( - exclude_unset=True, exclude={'header', 'version', 'sequence'}, exclude_none=True + exclude_unset=True, + exclude={'header', 'version', 'sequence', 'smp_data'}, + exclude_none=True, ) ) - self._bytes: bytes if self.header is None: # create the header object.__setattr__( self, @@ -95,7 +98,7 @@ def model_post_init(self, _: None) -> None: ) object.__setattr__(self, 'sequence', self.header.sequence) else: # validate the header and update version & sequence - if self.header.length != len(data_bytes): + if self.smp_data is None and self.header.length != len(data_bytes): raise SMPMalformed( f"header.length {self.header.length} != len(data_bytes) {len(data_bytes)}" ) @@ -111,7 +114,8 @@ def model_post_init(self, _: None) -> None: "from the provided header." ) object.__setattr__(self, 'version', self.header.version) - self._bytes = self.header.BYTES + data_bytes + if self.smp_data is None: + object.__setattr__(self, 'smp_data', bytes(self.header) + data_bytes) class Request(_MessageBase, ABC):