From a035f302af62c873118e89b3801e759883935472 Mon Sep 17 00:00:00 2001 From: Scott Mabin Date: Mon, 11 Sep 2023 14:28:42 +0100 Subject: [PATCH] Remove usage of dyn, instead use an enum for runtime dispatch (#33) * Remove usage of dyn, instead use an enum for runtime dispatch * cfg rework, transport enum Remove much of the cfg blocks from varios places. Where possible, new concise cfgs have been introduced. Adds a transport to replace dyn. --- build.rs | 15 +++++++++ src/io.rs | 71 +++++++++++++++++++++++++++++++++++---- src/io/uart.rs | 6 ++-- src/io/usb_serial_jtag.rs | 4 ++- src/lib.rs | 37 +++++--------------- src/main.rs | 39 +++++++++++++++------ src/protocol.rs | 69 ++++++++++++++++++------------------- src/targets.rs | 25 ++++++++++++++ 8 files changed, 181 insertions(+), 85 deletions(-) diff --git a/build.rs b/build.rs index 10d6dd8..c0881a8 100644 --- a/build.rs +++ b/build.rs @@ -88,4 +88,19 @@ fn main() { println!("cargo:rerun-if-changed=ld/ld/esp32s3_rom.x"); println!("cargo:rustc-link-arg=-Tld/esp32s3_rom.x"); } + + emit_cfg(); +} + +fn emit_cfg() { + #[cfg(any( + feature = "esp32c3", + feature = "esp32s3", + feature = "esp32c6", + feature = "esp32h2" + ))] + println!("cargo:rustc-cfg=usb_device"); + + #[cfg(any(feature = "esp32s2", feature = "esp32s3"))] + println!("cargo:rustc-cfg=usb0"); } diff --git a/src/io.rs b/src/io.rs index 7e66f0c..7f66dc7 100644 --- a/src/io.rs +++ b/src/io.rs @@ -1,14 +1,73 @@ +use core::marker::PhantomData; + use heapless::Deque; +use crate::protocol::InputIO; + pub mod uart; -#[cfg(any( - feature = "esp32c3", - feature = "esp32s3", - feature = "esp32c6", - feature = "esp32h2" -))] +#[cfg(usb_device)] pub mod usb_serial_jtag; const RX_QUEUE_SIZE: usize = crate::targets::MAX_WRITE_BLOCK + 0x400; static mut RX_QUEUE: Deque = Deque::new(); + +trait UartMarker: InputIO {} +trait UsbSerialJtagMarker: InputIO {} +trait UsbOtgMarker: InputIO {} + +#[non_exhaustive] +pub enum Transport { + Uart(S), + #[cfg(usb_device)] + UsbSerialJtag(J), + #[cfg(usb0)] + UsbOtg(U), + #[doc(hidden)] + __Hidden(PhantomData, PhantomData), +} + +impl InputIO for Transport +where + S: UartMarker, + J: UsbSerialJtagMarker, + U: UsbOtgMarker, +{ + fn recv(&mut self) -> u8 { + match self { + Transport::Uart(s) => s.recv(), + #[cfg(usb_device)] + Transport::UsbSerialJtag(j) => j.recv(), + _ => todo!(), + } + } + + fn send(&mut self, data: &[u8]) { + match self { + Transport::Uart(s) => s.send(data), + #[cfg(usb_device)] + Transport::UsbSerialJtag(j) => j.send(data), + _ => todo!(), + } + } +} + +pub struct Noop; + +impl InputIO for Noop { + fn recv(&mut self) -> u8 { + todo!() + } + + fn send(&mut self, _data: &[u8]) { + todo!() + } +} + +impl UartMarker for Noop {} +impl UsbSerialJtagMarker for Noop {} +impl UsbOtgMarker for Noop {} + +impl UartMarker for &mut T {} +impl UsbSerialJtagMarker for &mut T {} +impl UsbOtgMarker for &mut T {} diff --git a/src/io/uart.rs b/src/io/uart.rs index 75053ad..8bb5d38 100644 --- a/src/io/uart.rs +++ b/src/io/uart.rs @@ -1,4 +1,4 @@ -use super::RX_QUEUE; +use super::{UartMarker, RX_QUEUE}; use crate::{ hal::{peripherals::UART0, prelude::*, uart::Instance, Uart}, protocol::InputIO, @@ -15,6 +15,8 @@ impl InputIO for Uart<'_, T> { } } +impl UartMarker for Uart<'_, T> {} + #[interrupt] fn UART0() { let uart = unsafe { &*UART0::ptr() }; @@ -30,7 +32,7 @@ fn UART0() { // the read _must_ be a word read so the hardware correctly detects the read and // pops the byte from the fifo cast the result to a u8, as only the // first byte contains the data - let data = unsafe { (uart.fifo.as_ptr() as *mut u32).offset(offset).read() } as u8; + let data = unsafe { uart.fifo.as_ptr().offset(offset).read() } as u8; unsafe { RX_QUEUE.push_back(data).unwrap() }; } diff --git a/src/io/usb_serial_jtag.rs b/src/io/usb_serial_jtag.rs index 5e7a6a7..32be16c 100644 --- a/src/io/usb_serial_jtag.rs +++ b/src/io/usb_serial_jtag.rs @@ -1,4 +1,4 @@ -use super::RX_QUEUE; +use super::{UsbSerialJtagMarker, RX_QUEUE}; use crate::{ hal::{ prelude::*, @@ -18,6 +18,8 @@ impl InputIO for UsbSerialJtag<'_> { } } +impl UsbSerialJtagMarker for UsbSerialJtag<'_> {} + #[interrupt] unsafe fn USB_DEVICE() { let usj = crate::hal::peripherals::USB_DEVICE::steal(); diff --git a/src/lib.rs b/src/lib.rs index adce3c0..29b7ac6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,9 +15,6 @@ pub use esp32h2_hal as hal; pub use esp32s2_hal as hal; #[cfg(feature = "esp32s3")] pub use esp32s3_hal as hal; -// Due to a bug in esp-hal this MUST be included in the root. -#[cfg(target_arch = "riscv32")] -pub use hal::interrupt; // Re-export the correct target based on which feature is active #[cfg(feature = "esp32")] pub use targets::Esp32 as target; @@ -44,18 +41,15 @@ pub mod targets; #[derive(Debug)] pub enum TransportMethod { Uart, - #[cfg(any( - feature = "esp32c3", - feature = "esp32s3", - feature = "esp32c6", - feature = "esp32h2" - ))] + #[cfg(usb_device)] UsbSerialJtag, - #[cfg(any(feature = "esp32s2", feature = "esp32s3"))] + #[cfg(usb0)] UsbOtg, } pub fn detect_transport() -> TransportMethod { + #[allow(unused)] + use targets::{EspUsbOtgId, EspUsbSerialJtagId}; #[repr(C)] struct Uart { baud_rate: u32, @@ -72,28 +66,13 @@ pub fn detect_transport() -> TransportMethod { extern "C" { fn esp_flasher_rom_get_uart() -> *const Uart; } - #[cfg(any(feature = "esp32c3", feature = "esp32c6", feature = "esp32h2"))] - const USB_SERIAL_JTAG: u8 = 3; - #[cfg(any(feature = "esp32s3"))] - const USB_SERIAL_JTAG: u8 = 4; - - #[cfg(feature = "esp32s3")] - const USB_OTG: u8 = 3; - #[cfg(feature = "esp32s2")] - const USB_OTG: u8 = 2; - let device = unsafe { esp_flasher_rom_get_uart() }; let num = unsafe { (*device).buff_uart_no }; match num { - #[cfg(any( - feature = "esp32c3", - feature = "esp32s3", - feature = "esp32c6", - feature = "esp32h2" - ))] - USB_SERIAL_JTAG => TransportMethod::UsbSerialJtag, - #[cfg(any(feature = "esp32s2", feature = "esp32s3"))] - USB_OTG => TransportMethod::UsbOtg, + #[cfg(usb_device)] + target::USB_SERIAL_JTAG_ID => TransportMethod::UsbSerialJtag, + #[cfg(usb0)] + target::USB_OTG_ID => TransportMethod::UsbOtg, _ => TransportMethod::Uart, } } diff --git a/src/main.rs b/src/main.rs index acf09d0..c3a4228 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,13 +8,35 @@ use flasher_stub::hal::uart::{ }; use flasher_stub::{ hal::{clock::ClockControl, interrupt, peripherals, prelude::*, Uart, IO}, - protocol::{InputIO, Stub}, + io::Noop, + protocol::Stub, targets, }; use static_cell::StaticCell; const MSG_BUFFER_SIZE: usize = targets::MAX_WRITE_BLOCK + 0x400; +// TODO this sucks, but default generic parameters are not used when inference +// fails, meaning that we _have_ to specifiy the types here Seems like work on this has stalled: https://github.com/rust-lang/rust/issues/27336, note that I tried the feature and it didn't work. +#[cfg(not(any(usb_device, usb0)))] +type Transport = + flasher_stub::io::Transport<&'static mut Uart<'static, crate::peripherals::UART0>, Noop, Noop>; +#[cfg(all(usb_device, not(usb0)))] +type Transport = flasher_stub::io::Transport< + &'static mut Uart<'static, crate::peripherals::UART0>, + &'static mut flasher_stub::hal::UsbSerialJtag<'static>, + Noop, +>; +#[cfg(all(not(usb_device), usb0))] +type Transport = + flasher_stub::io::Transport<&'static mut Uart<'static, crate::peripherals::UART0>, Noop, Noop>; // TODO replace Noop with usb type later +#[cfg(all(usb_device, usb0))] +type Transport = flasher_stub::io::Transport< + &'static mut Uart<'static, crate::peripherals::UART0>, + &'static mut flasher_stub::hal::UsbSerialJtag<'static>, + Noop, // TODO replace Noop with usb type later +>; + #[flasher_stub::hal::entry] fn main() -> ! { let peripherals = peripherals::Peripherals::take(); @@ -62,7 +84,7 @@ fn main() -> ! { let transport = flasher_stub::detect_transport(); flasher_stub::dprintln!("Stub init! Transport detected: {:?}", transport); - let transport: &'static mut dyn InputIO = match transport { + let transport = match transport { flasher_stub::TransportMethod::Uart => { let mut serial = Uart::new(peripherals.UART0, &mut system.peripheral_clock_control); @@ -76,14 +98,9 @@ fn main() -> ! { static mut TRANSPORT: StaticCell> = StaticCell::new(); - unsafe { TRANSPORT.init(serial) } + Transport::Uart(unsafe { TRANSPORT.init(serial) }) } - #[cfg(any( - feature = "esp32c3", - feature = "esp32s3", - feature = "esp32c6", - feature = "esp32h2" - ))] + #[cfg(usb_device)] flasher_stub::TransportMethod::UsbSerialJtag => { let mut usb_serial = flasher_stub::hal::UsbSerialJtag::new( peripherals.USB_DEVICE, @@ -98,9 +115,9 @@ fn main() -> ! { static mut TRANSPORT: StaticCell> = StaticCell::new(); - unsafe { TRANSPORT.init(usb_serial) } + Transport::UsbSerialJtag(unsafe { TRANSPORT.init(usb_serial) }) } - #[cfg(any(feature = "esp32s2", feature = "esp32s3"))] + #[cfg(usb0)] flasher_stub::TransportMethod::UsbOtg => unimplemented!(), }; diff --git a/src/protocol.rs b/src/protocol.rs index 8b8c6c6..88d9d23 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -20,8 +20,18 @@ pub trait InputIO { fn send(&mut self, data: &[u8]); } -pub struct Stub<'a> { - io: &'a mut (dyn InputIO + 'a), +impl InputIO for &mut T { + fn recv(&mut self) -> u8 { + (*self).recv() + } + + fn send(&mut self, data: &[u8]) { + (*self).send(data) + } +} + +pub struct Stub { + io: T, end_addr: u32, write_addr: u32, erase_addr: u32, @@ -30,20 +40,7 @@ pub struct Stub<'a> { decompressor: tinfl_decompressor, last_error: Option, in_flash_mode: bool, - #[cfg(feature = "esp32c3")] - target: crate::targets::Esp32c3, - #[cfg(feature = "esp32c6")] - target: crate::targets::Esp32c6, - #[cfg(feature = "esp32h2")] - target: crate::targets::Esp32h2, - #[cfg(feature = "esp32c2")] - target: crate::targets::Esp32c2, - #[cfg(feature = "esp32")] - target: crate::targets::Esp32, - #[cfg(feature = "esp32s3")] - target: crate::targets::Esp32s3, - #[cfg(feature = "esp32s2")] - target: crate::targets::Esp32s2, + target: crate::target, } fn slice_to_struct(slice: &[u8]) -> Result { @@ -63,8 +60,8 @@ fn u32_from_slice(slice: &[u8], index: usize) -> u32 { u32::from_le_bytes(slice[index..index + 4].try_into().unwrap()) } -impl<'a> Stub<'a> { - pub fn new(input_io: &'a mut dyn InputIO) -> Self { +impl Stub { + pub fn new(input_io: T) -> Self { let stub = Stub { io: input_io, write_addr: 0, @@ -84,19 +81,19 @@ impl<'a> Stub<'a> { fn send_response(&mut self, resp: &Response) { let resp_slice = unsafe { to_slice_u8(resp) }; - write_delimiter(self.io); - write_raw(self.io, &resp_slice[..RESPONSE_SIZE]); - write_raw(self.io, resp.data); - write_delimiter(self.io); + write_delimiter(&mut self.io); + write_raw(&mut self.io, &resp_slice[..RESPONSE_SIZE]); + write_raw(&mut self.io, resp.data); + write_delimiter(&mut self.io); } fn send_response_with_data(&mut self, resp: &Response, data: &[u8]) { let resp_slice = unsafe { to_slice_u8(resp) }; - write_delimiter(self.io); - write_raw(self.io, &resp_slice[..RESPONSE_SIZE - 2]); - write_raw(self.io, data); - write_raw(self.io, &resp_slice[RESPONSE_SIZE - 2..RESPONSE_SIZE]); - write_delimiter(self.io); + write_delimiter(&mut self.io); + write_raw(&mut self.io, &resp_slice[..RESPONSE_SIZE - 2]); + write_raw(&mut self.io, data); + write_raw(&mut self.io, &resp_slice[RESPONSE_SIZE - 2..RESPONSE_SIZE]); + write_delimiter(&mut self.io); } fn send_md5_response(&mut self, resp: &Response, md5: &[u8]) { @@ -109,7 +106,7 @@ impl<'a> Stub<'a> { pub fn send_greeting(&mut self) { let greeting = [b'O', b'H', b'A', b'I']; - write_packet(self.io, &greeting); + write_packet(&mut self.io, &greeting); } fn calculate_md5(&mut self, mut address: u32, mut size: u32) -> Result<[u8; 16], Error> { @@ -374,18 +371,18 @@ impl<'a> Stub<'a> { let len = min(params.packet_size, remaining); self.target .spi_flash_read(address, &mut buffer[..len as usize])?; - write_packet(self.io, &buffer[..len as usize]); + write_packet(&mut self.io, &buffer[..len as usize]); hasher.consume(&buffer[0..len as usize]); remaining -= len; address += len; sent += len; } - let resp = read_packet(self.io, &mut ack_buf); + let resp = read_packet(&mut self.io, &mut ack_buf); acked = u32_from_slice(resp, 0); } let md5: [u8; 16] = hasher.compute().into(); - write_packet(self.io, &md5); + write_packet(&mut self.io, &md5); Ok(()) } @@ -497,14 +494,14 @@ impl<'a> Stub<'a> { } pub fn read_command<'c>(&mut self, buffer: &'c mut [u8]) -> &'c [u8] { - read_packet(self.io, buffer) + read_packet(&mut self.io, buffer) } } mod slip { use super::*; - pub fn read_packet<'c>(io: &mut dyn InputIO, packet: &'c mut [u8]) -> &'c [u8] { + pub fn read_packet<'c, T: InputIO>(io: &mut T, packet: &'c mut [u8]) -> &'c [u8] { while io.recv() != 0xC0 {} // Replace: 0xDB 0xDC -> 0xC0 and 0xDB 0xDD -> 0xDB @@ -525,7 +522,7 @@ mod slip { &packet[..i] } - pub fn write_raw(io: &mut dyn InputIO, data: &[u8]) { + pub fn write_raw(io: &mut T, data: &[u8]) { for byte in data { match byte { 0xC0 => io.send(&[0xDB, 0xDC]), @@ -535,13 +532,13 @@ mod slip { } } - pub fn write_packet(io: &mut dyn InputIO, data: &[u8]) { + pub fn write_packet(io: &mut T, data: &[u8]) { write_delimiter(io); write_raw(io, data); write_delimiter(io); } - pub fn write_delimiter(io: &mut dyn InputIO) { + pub fn write_delimiter(io: &mut T) { io.send(&[0xC0]); } } diff --git a/src/targets.rs b/src/targets.rs index 3fd65c8..23aa555 100644 --- a/src/targets.rs +++ b/src/targets.rs @@ -399,3 +399,28 @@ impl EspCommon for Esp32c3 {} impl EspCommon for Esp32c2 {} impl EspCommon for Esp32c6 {} impl EspCommon for Esp32h2 {} + +pub trait EspUsbSerialJtagId { + /// The ID returned for USB_SERIAL_JTAG from `esp_flasher_rom_get_uart` + const USB_SERIAL_JTAG_ID: u8 = 3; // default for most chips is 3 +} + +impl EspUsbSerialJtagId for Esp32c3 {} +impl EspUsbSerialJtagId for Esp32c6 {} +impl EspUsbSerialJtagId for Esp32h2 {} +impl EspUsbSerialJtagId for Esp32s3 { + const USB_SERIAL_JTAG_ID: u8 = 4; +} + +pub trait EspUsbOtgId { + /// The ID returned for USB_OTG from `esp_flasher_rom_get_uart` + const USB_OTG_ID: u8; +} + +impl EspUsbOtgId for Esp32s2 { + const USB_OTG_ID: u8 = 2; +} + +impl EspUsbOtgId for Esp32s3 { + const USB_OTG_ID: u8 = 3; +}