diff --git a/ltchiptool/soc/ambz2/flash.py b/ltchiptool/soc/ambz2/flash.py index dee27df..b0c0e54 100644 --- a/ltchiptool/soc/ambz2/flash.py +++ b/ltchiptool/soc/ambz2/flash.py @@ -107,13 +107,17 @@ def flash_write_raw( self.flash_connect() assert self.amb callback.attach(data, limit=length) - self.amb.memory_write( - offset=offset, - stream=data, - use_flash=True, - hash_check=verify, - ) - callback.detach(data) + try: + self.amb.memory_write( + offset=offset, + stream=data, + use_flash=True, + hash_check=verify, + ) + callback.detach(data) + except Exception as e: + callback.detach(data) + raise e def flash_write_uf2( self, diff --git a/ltchiptool/soc/ambz2/util/ambz2tool.py b/ltchiptool/soc/ambz2/util/ambz2tool.py index 386a8b2..2c91bc5 100644 --- a/ltchiptool/soc/ambz2/util/ambz2tool.py +++ b/ltchiptool/soc/ambz2/util/ambz2tool.py @@ -13,7 +13,7 @@ from xmodem import XMODEM from ltchiptool.util.intbin import align_down -from ltchiptool.util.logging import LoggingHandler, verbose +from ltchiptool.util.logging import LoggingHandler, stream, verbose from ltchiptool.util.misc import retry_catching, retry_generator _T_XmodemCB = Optional[Callable[[int, int, int], None]] @@ -34,7 +34,7 @@ class AmbZ2FlashSpeed(IntEnum): class AmbZ2Tool: - crc_speed_bps: int = 2000000 + crc_speed_bps: int = 1500000 prev_timeout_list: List[float] flash_mode: AmbZ2FlashMode = None flash_speed: AmbZ2FlashSpeed = AmbZ2FlashSpeed.SINGLE @@ -83,16 +83,16 @@ def command(self, cmd: str) -> None: def read(self, count: int = None) -> bytes: response = b"" end = time() + self.read_timeout - end_nb = time() + 0.01 # not before while time() < end: - read = self.s.read_all() + if count: + read = self.s.read(count - len(response)) + else: + read = self.s.read_all() response += read if count and len(response) >= count: break - if not response or time() <= end_nb: - continue - if not read: - break + if read: + end = time() + self.read_timeout if not response: raise TimeoutError(f"Timeout in read({count}) - no data received") @@ -100,7 +100,9 @@ def read(self, count: int = None) -> bytes: return response response = response[:count] if len(response) != count: - raise TimeoutError(f"Timeout in read({count}) - not enough data received") + raise TimeoutError( + f"Timeout in read({count}) - not enough data received ({len(response)})" + ) return response def readlines(self) -> Generator[str, None, None]: @@ -154,7 +156,7 @@ def ping(self) -> None: self.command("ping") resp = self.read(4) if resp != b"ping": - raise RuntimeError(f"incorrect ping response: {resp!r}") + raise RuntimeError(f"Incorrect ping response: {resp!r}") def disconnect(self) -> None: self.command("disc") @@ -175,13 +177,16 @@ def change_baudrate(self, baudrate: int) -> None: self.ping() self.command(f"ucfg {baudrate} 0 0") # change Serial port baudrate - debug("-- UART: Changing port baudrate") + stream("-- UART: Changing port baudrate") self.s.baudrate = baudrate # wait up to 1 second for OK response self.push_timeout(1.0) - resp = self.read() + try: + resp = self.read() + except TimeoutError: + raise RuntimeError("Timed out while changing baud rate") if resp != b"OK": - raise RuntimeError(f"baud rate change not OK: {resp!r}") + raise RuntimeError(f"Baud rate change not OK: {resp!r}") self.pop_timeout() # link again to make sure it still works diff --git a/ltchiptool/soc/interface.py b/ltchiptool/soc/interface.py index 56b00ce..14a0ec1 100644 --- a/ltchiptool/soc/interface.py +++ b/ltchiptool/soc/interface.py @@ -141,7 +141,7 @@ def flash_get_chip_info_string(self) -> str: def flash_get_guide(self) -> List[Union[str, list]]: """Get a short textual guide for putting the chip in download mode.""" - raise NotImplementedError() + return [] # Optional; do not fail here def flash_get_size(self) -> int: """Retrieve the flash size, in bytes.""" diff --git a/ltchiptool/util/streams.py b/ltchiptool/util/streams.py index 0772bb9..328d9c2 100644 --- a/ltchiptool/util/streams.py +++ b/ltchiptool/util/streams.py @@ -110,7 +110,7 @@ def __init__(self): self.buf = {"-> RX": "", "<- TX": ""} def _print(self, data: bytes, msg: str): - if all(c in self.ASCII for c in data): + if data and all(c in self.ASCII for c in data): data = data.decode().replace("\r", "") while "\n" in data: line, _, data = data.partition("\n") @@ -118,12 +118,14 @@ def _print(self, data: bytes, msg: str): self.buf[msg] = "" if line: stream(f"{msg}: '{line}'") - self.buf[msg] = data + self.buf[msg] += data return if self.buf[msg]: stream(f"{msg}: '{self.buf[msg]}'") self.buf[msg] = "" + if not data: + return if data.isascii(): stream(f"{msg}: {data[0:128]}") @@ -138,6 +140,7 @@ def on_after_read(self, data: bytes) -> Optional[bytes]: return None def on_before_write(self, data: bytes) -> Optional[bytes]: + self._print(b"", "-> RX") # print leftover bytes self._print(data, "<- TX") return None