diff --git a/ltchiptool/soc/ambz2/util/ambz2tool.py b/ltchiptool/soc/ambz2/util/ambz2tool.py index 954fd26..a5f56be 100644 --- a/ltchiptool/soc/ambz2/util/ambz2tool.py +++ b/ltchiptool/soc/ambz2/util/ambz2tool.py @@ -6,18 +6,49 @@ from logging import debug, warning from math import ceil from time import time -from typing import IO, Callable, Generator, List, Optional +from typing import IO, Callable, Generator, List, Optional, Tuple import click +from hexdump import hexdump, restore +from serial import Serial from xmodem import XMODEM -from ltchiptool.util.intbin import align_down +from ltchiptool.util.cli import DevicePortParamType +from ltchiptool.util.intbin import align_down, biniter, gen2bytes, letoint, pad_data from ltchiptool.util.logging import LoggingHandler, verbose from ltchiptool.util.misc import retry_catching, retry_generator from ltchiptool.util.serialtool import SerialToolBase _T_XmodemCB = Optional[Callable[[int, int, int], None]] +AMBZ2_FALLBACK_CMD = b"Rtk8710C\n" +AMBZ2_FALLBACK_RESP = [ + b"\r\n$8710c>" * 2, + b"Rtk8710C\r\nCommand NOT found.\r\n$8710c>", +] +USED_COMMANDS = [ + "ping", + "disc", + "ucfg", + "DW", + "DB", + "EW", + "EB", + "WDTRST", + "hashq", + "fwd", + "fwdram", +] + +AMBZ2_CODE_ADDR = 0x10037000 +AMBZ2_DATA_ADDR = 0x10038000 +AMBZ2_EFUSE_PHYSICAL_SIZE = 512 +AMBZ2_EFUSE_LOGICAL_SIZE = 512 + +AMBZ2_CHIP_TYPE = { + 0xFE: "RTL87x0CF", +} + class AmbZ2FlashMode(IntEnum): RTL8720CX_CM = 0 # PIN_A7_A12 @@ -38,6 +69,8 @@ class AmbZ2Tool(SerialToolBase): flash_mode: AmbZ2FlashMode = None flash_speed: AmbZ2FlashSpeed = AmbZ2FlashSpeed.SINGLE flash_hash_offset: int = None + in_fallback_mode: bool = False + boot_cmd: Tuple[int, str] = None def __init__( self, @@ -65,18 +98,26 @@ def flash_cfg(self) -> str: def command(self, cmd: str) -> None: self.flush() - self.s.write(cmd.encode() + b"\n") + cmd = cmd.encode() + self.s.write(cmd + b"\n") + if self.in_fallback_mode: + self.s.read(len(cmd) + 2) def ping(self) -> None: self.command("ping") resp = self.read(4) if resp != b"ping": raise RuntimeError(f"Incorrect ping response: {resp!r}") + resp = self.s.read_all() + if b"$8710c" in resp: + raise RuntimeError(f"Got fallback mode ping: {resp!r}") def disconnect(self) -> None: self.command("disc") def link(self) -> None: + # try linking in fallback mode - 'ping' before that would break it + self.link_fallback() end = time() + self.link_timeout while time() < end: try: @@ -86,6 +127,26 @@ def link(self) -> None: pass raise TimeoutError("Timeout while linking") + def link_fallback(self) -> None: + self.flush() + self.write(AMBZ2_FALLBACK_CMD) + self.push_timeout(0.1) + try: + response = self.read() + if response not in AMBZ2_FALLBACK_RESP: + return + except TimeoutError: + return + finally: + self.pop_timeout() + debug(f"Found fallback mode with response: {response}") + self.in_fallback_mode = True + # check ROM version + chip_ver = (self.register_read(0x4000_01F0) >> 4) & 0xF + # jump to download mode + self.memory_boot(0x0 if chip_ver > 2 else 0x1443C) + self.in_fallback_mode = False + def change_baudrate(self, baudrate: int) -> None: if self.s.baudrate == baudrate: return @@ -125,7 +186,7 @@ def dump_words(self, start: int, count: int) -> Generator[List[int], None, None] line = line.split() addr = int(line[0].rstrip(":"), 16) if addr != start + read_count: - raise ValueError("Got invalid read address") + raise ValueError(f"Got invalid read address: {line}") chunk = list() for i, value in enumerate(line[1 : 1 + 4]): @@ -184,6 +245,27 @@ def register_write(self, address: int, value: int) -> None: self.command(f"EW {address:X} {value:X}") next(self.readlines()) + def register_read_bytes(self, address: int, length: int) -> bytes: + start = align_down(address, 4) + return gen2bytes(self.dump_bytes(start, length))[0:length] + + def register_write_bytes(self, address: int, value: bytes) -> None: + start = align_down(address, 4) + value = pad_data(value, 4, 0x00) + words = [] + for word in biniter(value, 4): + words.append(f"{letoint(word):X}") + # 'EW' command can theoretically write at most 8 words, + # but it seems to cut the command off at around 80 bytes + for i in range(0, len(words), 7): + chunk = words[i : i + 7] + command = f"EW {start + i * 4:X} " + command += " ".join(chunk) + self.command(command) + lines = self.readlines() + for _ in chunk: + next(lines) + def sw_reset(self) -> None: self.command("WDTRST") @@ -428,12 +510,104 @@ def memory_write( f"{hash_expected.hex()}, calculated: {hash_final.hex()}" ) + def memory_boot( + self, + address: int, + force_find: bool = False, + ) -> None: + address |= 1 + if self.boot_cmd is None or force_find: + # find ROM console command array + cmd_array = self.register_read(0x1002F050 + 4) + cmd_size = 4 * 3 + # try all commands to find an unused one + for cmd_ptr in range(cmd_array, cmd_array + 8 * cmd_size, cmd_size): + # read command name pointer + name_ptr = self.register_read(cmd_ptr + 0) + if name_ptr == 0: + break + # read command name + cmd_name = b"".join(self.dump_bytes(name_ptr, 16)) + cmd_name = cmd_name.partition(b"\x00")[0] + if not cmd_name.isascii(): + warning(f"Non-ASCII command string @ 0x{name_ptr:X}: {cmd_name}") + continue + cmd_name = cmd_name.decode() + if cmd_name in USED_COMMANDS: + continue + func_ptr = cmd_ptr + 4 + self.boot_cmd = func_ptr, cmd_name + if self.boot_cmd is None: + raise RuntimeError("No unused ROM command found, cannot boot from SRAM") + + func_ptr, cmd_name = self.boot_cmd + # write new command handler address + self.register_write(func_ptr, address) + debug(f"Jumping to 0x{address:X} with command '{cmd_name}'") + # execute command to jump to the function + self.command(cmd_name) + @click.command( help="AmebaZ2 flashing tool", ) -def cli(): - raise NotImplementedError() +@click.option( + "-d", + "--device", + help="Target device port (default: auto detect)", + type=DevicePortParamType(), + default=(), +) +def cli(device: str): + s = Serial(device, 115200) + s.timeout = 0.01 + + while True: + cmd = input("> ") + + if cmd == "m": + try: + while True: + read = s.read_all() + if read: + print(read.decode(errors="replace"), end="") + except KeyboardInterrupt: + continue + + s.write(cmd.encode()) + s.write(b"\r\n") + response = b"" + start = time() + + if cmd.startswith("ucfg"): + s.close() + baud = int(cmd.split(" ")[1]) + s = Serial(device, baud) + + if cmd.startswith("DB"): + f = open(cmd + ".bin", "wb") + while True: + try: + read = s.read_all() + if read: + print(read.decode(), end="") + response += read + while b"\n" in response: + line, _, response = response.partition(b"\n") + line = line.decode() + line = line.strip() + if line and "[Addr]" not in line: + f.write(restore(line)) + except KeyboardInterrupt: + break + f.close() + continue + + while True: + response += s.read_all() + if time() > start + 0.5: + break + hexdump(response) if __name__ == "__main__":