Skip to content

Commit

Permalink
[ambz] Replace rtltool usage with ambztool
Browse files Browse the repository at this point in the history
  • Loading branch information
kuba2k2 committed Nov 2, 2023
1 parent c0c9e5c commit c9686eb
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 52 deletions.
122 changes: 75 additions & 47 deletions ltchiptool/soc/ambz/flash.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
# Copyright (c) Kuba Szczodrzyński 2022-07-29.

from abc import ABC
from io import BytesIO
from logging import debug, warning
from time import sleep
from typing import IO, Generator, List, Optional, Union

from ltchiptool import SocInterface
from ltchiptool.util.flash import FlashConnection
from ltchiptool.util.intbin import gen2bytes, inttole32, letoint
from ltchiptool.util.intbin import gen2bytes, letoint
from ltchiptool.util.streams import ProgressCallback
from uf2tool import OTAScheme, UploadContext

from .util.rtltool import CAN, RTL_ROM_BAUD, RTLXMD
from .util.ambzcode import AmbZCode
from .util.ambztool import (
AMBZ_CHIP_TYPE,
AMBZ_FLASH_ADDRESS,
AMBZ_ROM_BAUDRATE,
AmbZTool,
)

AMEBAZ_GUIDE = [
"Connect UART2 of the Realtek chip to the USB-TTL adapter:",
Expand All @@ -37,38 +43,39 @@

# noinspection PyProtectedMember
class AmebaZFlash(SocInterface, ABC):
rtl: Optional[RTLXMD] = None
is_can_sent: bool = False
amb: Optional[AmbZTool] = None
chip_info: bytes = None

def flash_set_connection(self, connection: FlashConnection) -> None:
if self.conn:
self.flash_disconnect()
self.conn = connection
self.conn.fill_baudrate(RTL_ROM_BAUD)
# use 460800 max. as default, since most cheap adapters can't go faster anyway
self.conn.fill_baudrate(460800, link_baudrate=AMBZ_ROM_BAUDRATE)

def flash_build_protocol(self, force: bool = False) -> None:
if not force and self.rtl:
if not force and self.amb:
return
self.flash_disconnect()
self.rtl = RTLXMD(
self.amb = AmbZTool(
port=self.conn.port,
baud=self.conn.link_baudrate,
timeout=0.07,
baudrate=self.conn.link_baudrate,
read_timeout=0.2,
)
self.flash_change_timeout(self.conn.timeout, self.conn.link_timeout)

def flash_change_timeout(self, timeout: float = 0.0, link_timeout: float = 0.0):
self.flash_build_protocol()
if timeout:
self.rtl._port.timeout = timeout
self.amb.read_timeout = timeout
self.conn.timeout = timeout
if link_timeout:
self.rtl.sync_timeout = link_timeout
self.amb.link_timeout = link_timeout
self.conn.link_timeout = link_timeout

def flash_sw_reset(self) -> None:
self.flash_build_protocol()
port = self.rtl._port
port = self.amb.s
prev_baudrate = port.baudrate
port.baudrate = 115200
sleep(0.1)
Expand All @@ -78,38 +85,52 @@ def flash_sw_reset(self) -> None:
sleep(0.5)
port.baudrate = prev_baudrate

def flash_hw_reset(self) -> None:
self.flash_build_protocol()
self.rtl.connect()

def flash_connect(self) -> None:
if self.rtl and self.conn.linked:
if self.amb and self.conn.linked:
return
self.flash_build_protocol()
if not self.is_can_sent:
# try to exit interrupted write operations
# sending 'CAN' exits the download mode, unless it's invoked via hardware
self.rtl._port.write(CAN)
self.is_can_sent = True
if not self.rtl.sync():
raise TimeoutError(f"Failed to connect on port {self.conn.port}")
assert self.amb
self.amb.link()
self.amb.change_baudrate(self.conn.baudrate)
self.conn.linked = True

def flash_disconnect(self) -> None:
if self.rtl:
self.rtl._port.close()
self.rtl._port = None
self.rtl = None
if self.amb:
try:
self.amb.link(disconnect=True)
except TimeoutError:
pass
self.amb.close()
self.amb = None
if self.conn:
self.conn.linked = False

def _read_chip_info(self) -> None:
self.flash_connect()
assert self.amb
self.chip_info = self.amb.ram_boot_read(
AmbZCode.read_chip_id(offset=0)
+ AmbZCode.read_flash_id(offset=1)
+ AmbZCode.print_data(length=4)
)
debug(f"Received chip info: {self.chip_info.hex()}")

def flash_get_chip_info_string(self) -> str:
return "Realtek RTL87xxB"
if not self.chip_info:
self._read_chip_info()
chip_id = self.chip_info[0]
return AMBZ_CHIP_TYPE.get(chip_id, f"Unknown 0x{chip_id:02X}")

def flash_get_guide(self) -> List[Union[str, list]]:
return AMEBAZ_GUIDE

def flash_get_size(self) -> int:
if not self.chip_info:
self._read_chip_info()
size_id = self.chip_info[3]
if 0x14 <= size_id <= 0x19:
return 1 << size_id
warning(f"Couldn't process flash ID: got {self.chip_info!r}")
return 0x200000

def flash_get_rom_size(self) -> int:
Expand All @@ -126,10 +147,13 @@ def flash_read_raw(
if use_rom:
self.flash_get_rom_size()
self.flash_connect()
gen = self.rtl.ReadBlockFlashGenerator(offset, length)
success = yield from callback.update_with(gen)
if not success:
raise ValueError(f"Failed to read from 0x{offset:X}")
assert self.amb
gen = self.amb.flash_read(
offset=offset,
length=length,
hash_check=verify,
)
yield from callback.update_with(gen)

def flash_write_raw(
self,
Expand All @@ -140,12 +164,17 @@ def flash_write_raw(
callback: ProgressCallback = ProgressCallback(),
) -> None:
self.flash_connect()
offset |= 0x8000000
callback.attach(data)
success = self.rtl.WriteBlockFlash(data, offset, length)
callback.detach(data)
if not success:
raise ValueError(f"Failed to write to 0x{offset:X}")
assert self.amb
callback.attach(data, limit=length)
try:
self.amb.memory_write(
address=AMBZ_FLASH_ADDRESS | offset,
stream=data,
)
callback.detach(data)
except Exception as e:
callback.detach(data)
raise e

def flash_write_uf2(
self,
Expand All @@ -156,14 +185,16 @@ def flash_write_uf2(
# read system data to get active OTA index
callback.on_message("Checking OTA index...")
system = gen2bytes(self.flash_read_raw(0x9000, 256))
if len(system) != 256:
if len(system) < 256:
raise ValueError(
f"Length invalid while reading from 0x9000 - {len(system)}"
)

# read OTA switch value
ota_switch = f"{letoint(system[4:8]):032b}"
# count 0-bits
ota_idx = 1 + (ota_switch.count("0") % 2)

# validate OTA2 address in system data
if ota_idx == 2:
ota2_addr = letoint(system[0:4]) & 0xFFFFFF
Expand All @@ -178,18 +209,15 @@ def flash_write_uf2(
parts = ctx.collect_data(
OTAScheme.FLASHER_DUAL_1 if ota_idx == 1 else OTAScheme.FLASHER_DUAL_2
)
callback.on_total(sum(len(part.getvalue()) for part in parts.values()) + 4)
callback.on_total(sum(len(part.getvalue()) for part in parts.values()))

callback.on_message(f"OTA {ota_idx}")
# write blocks to flash
for offset, data in parts.items():
length = len(data.getvalue())
callback.on_message(f"OTA {ota_idx} (0x{offset:06X})")
data.seek(0)
callback.on_message(f"OTA {ota_idx} (0x{offset:06X})")
self.flash_write_raw(offset, length, data, verify, callback)

callback.on_message("Booting firmware")
# [0x10002000] = 0x00005405
stream = BytesIO(inttole32(0x00005405))
self.rtl.WriteBlockSRAM(stream, 0x10002000, 4)
callback.on_update(4)
self.amb.ram_boot(address=0x00005405)
3 changes: 1 addition & 2 deletions ltchiptool/soc/ambz2/flash.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def flash_write_uf2(

# collect continuous blocks of data
parts = ctx.collect_data(OTAScheme.FLASHER_DUAL_1)
callback.on_total(sum(len(part.getvalue()) for part in parts.values()) + 4)
callback.on_total(sum(len(part.getvalue()) for part in parts.values()))

# write blocks to flash
for offset, data in parts.items():
Expand All @@ -150,4 +150,3 @@ def flash_write_uf2(

callback.on_message("Booting firmware")
self.amb.disconnect()
callback.on_update(4)
8 changes: 5 additions & 3 deletions ltchiptool/util/flash.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@ class FlashConnection:
link_timeout: float = 20.0
linked: bool = False

def fill_baudrate(self, baudrate: int) -> None:
self.link_baudrate = self.link_baudrate or baudrate
self.baudrate = self.baudrate or self.link_baudrate or baudrate
def fill_baudrate(self, baudrate: int, link_baudrate: int = None) -> None:
if link_baudrate is None:
link_baudrate = baudrate
self.link_baudrate = self.link_baudrate or link_baudrate
self.baudrate = self.baudrate or baudrate or self.link_baudrate


def format_flash_guide(soc) -> List[str]:
Expand Down

0 comments on commit c9686eb

Please sign in to comment.