From 7da4444a7e8680bdb12b313bacd325e7243c52e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Quentin?= Date: Wed, 13 Nov 2024 12:29:36 +0100 Subject: [PATCH] Fail RMT one-shot transactions if end-marker is missing (#2463) * Fail RMT one-shot transactions if end-marker is missing * CHANGELOG.md * Add test * Fix * Fix * RMT: use u32, turn PulseCode into a convenience trait * Clippy * Adapt test --- esp-hal/CHANGELOG.md | 1 + esp-hal/MIGRATING-0.21.md | 28 ++++++ esp-hal/src/rmt.rs | 145 +++++++++++++++-------------- examples/src/bin/embassy_rmt_rx.rs | 27 +++--- examples/src/bin/embassy_rmt_tx.rs | 16 +--- examples/src/bin/rmt_rx.rs | 30 +++--- examples/src/bin/rmt_tx.rs | 19 +--- hil-test/tests/rmt.rs | 66 ++++++++----- 8 files changed, 178 insertions(+), 154 deletions(-) diff --git a/esp-hal/CHANGELOG.md b/esp-hal/CHANGELOG.md index f57a9e67bbd..2bba6de8d04 100644 --- a/esp-hal/CHANGELOG.md +++ b/esp-hal/CHANGELOG.md @@ -62,6 +62,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `slave::Spi` constructors no longer take pins (#2485) - The `I2c` master driver has been moved from `esp_hal::i2c` to `esp_hal::i2c::master`. (#2476) - `I2c` SCL timeout is now defined in bus clock cycles. (#2477) +- Trying to send a single-shot RMT transmission will result in an error now, `RMT` deals with `u32` now, `PulseCode` is a convenience trait now (#2463) ### Fixed diff --git a/esp-hal/MIGRATING-0.21.md b/esp-hal/MIGRATING-0.21.md index 9304bddd0b6..ca6dfba8d06 100644 --- a/esp-hal/MIGRATING-0.21.md +++ b/esp-hal/MIGRATING-0.21.md @@ -368,3 +368,31 @@ If you were using an 16-bit bus, you don't need to change anything, `set_byte_or If you were sharing the bus between an 8-bit and 16-bit device, you will have to call the corresponding method when you switch between devices. Be sure to read the documentation of the new methods. + + +## `rmt::Channel::transmit` now returns `Result`, `PulseCode` is now `u32` + +When trying to send a one-shot transmission will fail if it doesn't end with an end-marker. + +```diff +- let mut data = [PulseCode { +- level1: true, +- length1: 200, +- level2: false, +- length2: 50, +- }; 20]; +- +- data[data.len() - 2] = PulseCode { +- level1: true, +- length1: 3000, +- level2: false, +- length2: 500, +- }; +- data[data.len() - 1] = PulseCode::default(); ++ let mut data = [PulseCode::new(true, 200, false, 50); 20]; ++ data[data.len() - 2] = PulseCode::new(true, 3000, false, 500); ++ data[data.len() - 1] = PulseCode::empty(); + +- let transaction = channel.transmit(&data); ++ let transaction = channel.transmit(&data).unwrap(); +``` diff --git a/esp-hal/src/rmt.rs b/esp-hal/src/rmt.rs index ef6f4b1063a..dc2d4e916e3 100644 --- a/esp-hal/src/rmt.rs +++ b/esp-hal/src/rmt.rs @@ -112,61 +112,62 @@ pub enum Error { InvalidArgument, /// An error occurred during transmission TransmissionError, + /// No transmission end marker found + EndMarkerMissing, } -/// Convenience representation of a pulse code entry. -/// -/// Allows for the assignment of two levels and their lengths -#[derive(Clone, Copy, Debug, Default, PartialEq)] -#[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub struct PulseCode { +/// Convenience trait to work with pulse codes. +pub trait PulseCode: crate::private::Sealed { + /// Create a new instance + fn new(level1: bool, length1: u16, level2: bool, length2: u16) -> Self; + + /// Create a new empty instance + fn empty() -> Self; + + /// Set all levels and lengths to 0 + fn reset(&mut self); + /// Logical output level in the first pulse code interval - pub level1: bool, + fn level1(&self) -> bool; + /// Length of the first pulse code interval (in clock cycles) - pub length1: u16, + fn length1(&self) -> u16; + /// Logical output level in the second pulse code interval - pub level2: bool, + fn level2(&self) -> bool; + /// Length of the second pulse code interval (in clock cycles) - pub length2: u16, + fn length2(&self) -> u16; } -impl From for PulseCode { - fn from(value: u32) -> Self { - Self { - level1: value & (1 << 15) != 0, - length1: (value & 0b111_1111_1111_1111) as u16, - level2: value & (1 << 31) != 0, - length2: ((value >> 16) & 0b111_1111_1111_1111) as u16, - } +impl PulseCode for u32 { + fn new(level1: bool, length1: u16, level2: bool, length2: u16) -> Self { + (((level1 as u32) << 15) | length1 as u32 & 0b111_1111_1111_1111) + | (((level2 as u32) << 15) | length2 as u32 & 0b111_1111_1111_1111) << 16 } -} -/// Convert a pulse code structure into a u32 value that can be written -/// into the data registers -impl From for u32 { - #[inline(always)] - fn from(p: PulseCode) -> u32 { - // The length1 value resides in bits [14:0] - let mut entry: u32 = p.length1 as u32; - - // If level1 is high, set bit 15, otherwise clear it - if p.level1 { - entry |= 1 << 15; - } else { - entry &= !(1 << 15); - } + fn empty() -> Self { + 0 + } - // If level2 is high, set bit 31, otherwise clear it - if p.level2 { - entry |= 1 << 31; - } else { - entry &= !(1 << 31); - } + fn reset(&mut self) { + *self = 0 + } - // The length2 value resides in bits [30:16] - entry |= (p.length2 as u32) << 16; + fn level1(&self) -> bool { + self & (1 << 15) != 0 + } + + fn length1(&self) -> u16 { + (self & 0b111_1111_1111_1111) as u16 + } + + fn level2(&self) -> bool { + self & (1 << 31) != 0 + } - entry + fn length2(&self) -> u16 { + ((self >> 16) & 0b111_1111_1111_1111) as u16 } } @@ -423,16 +424,16 @@ where } /// An in-progress transaction for a single shot TX transaction. -pub struct SingleShotTxTransaction<'a, C, T: Into + Copy> +pub struct SingleShotTxTransaction<'a, C> where C: TxChannel, { channel: C, index: usize, - data: &'a [T], + data: &'a [u32], } -impl + Copy> SingleShotTxTransaction<'_, C, T> +impl SingleShotTxTransaction<'_, C> where C: TxChannel, { @@ -466,7 +467,7 @@ where .enumerate() { unsafe { - ptr.add(idx).write_volatile((*entry).into()); + ptr.add(idx).write_volatile(*entry); } } @@ -982,26 +983,23 @@ pub trait TxChannel: TxChannelInternal { /// This returns a [`SingleShotTxTransaction`] which can be used to wait for /// the transaction to complete and get back the channel for further /// use. - fn transmit + Copy>(self, data: &[T]) -> SingleShotTxTransaction<'_, Self, T> + fn transmit(self, data: &[u32]) -> Result, Error> where Self: Sized, { - let index = Self::send_raw(data, false, 0); - SingleShotTxTransaction { + let index = Self::send_raw(data, false, 0)?; + Ok(SingleShotTxTransaction { channel: self, index, data, - } + }) } /// Start transmitting the given pulse code continuously. /// This returns a [`ContinuousTxTransaction`] which can be used to stop the /// ongoing transmission and get back the channel for further use. /// The length of sequence cannot exceed the size of the allocated RMT RAM. - fn transmit_continuously + Copy>( - self, - data: &[T], - ) -> Result, Error> + fn transmit_continuously(self, data: &[u32]) -> Result, Error> where Self: Sized, { @@ -1011,10 +1009,10 @@ pub trait TxChannel: TxChannelInternal { /// Like [`Self::transmit_continuously`] but also sets a loop count. /// [`ContinuousTxTransaction`] can be used to check if the loop count is /// reached. - fn transmit_continuously_with_loopcount + Copy>( + fn transmit_continuously_with_loopcount( self, loopcount: u16, - data: &[T], + data: &[u32], ) -> Result, Error> where Self: Sized, @@ -1023,21 +1021,21 @@ pub trait TxChannel: TxChannelInternal { return Err(Error::Overflow); } - let _index = Self::send_raw(data, true, loopcount); + let _index = Self::send_raw(data, true, loopcount)?; Ok(ContinuousTxTransaction { channel: self }) } } /// RX transaction instance -pub struct RxTransaction<'a, C, T: From + Copy> +pub struct RxTransaction<'a, C> where C: RxChannel, { channel: C, - data: &'a mut [T], + data: &'a mut [u32], } -impl + Copy> RxTransaction<'_, C, T> +impl RxTransaction<'_, C> where C: RxChannel, { @@ -1062,7 +1060,7 @@ where as *mut u32; let len = self.data.len(); for (idx, entry) in self.data.iter_mut().take(len).enumerate() { - *entry = unsafe { ptr.add(idx).read_volatile().into() }; + *entry = unsafe { ptr.add(idx).read_volatile() }; } Ok(self.channel) @@ -1075,10 +1073,7 @@ pub trait RxChannel: RxChannelInternal { /// This returns a [RxTransaction] which can be used to wait for receive to /// complete and get back the channel for further use. /// The length of the received data cannot exceed the allocated RMT RAM. - fn receive + Copy>( - self, - data: &mut [T], - ) -> Result, Error> + fn receive(self, data: &mut [u32]) -> Result, Error> where Self: Sized, { @@ -1143,7 +1138,7 @@ pub trait TxChannelAsync: TxChannelInternal { /// Start transmitting the given pulse code sequence. /// The length of sequence cannot exceed the size of the allocated RMT /// RAM. - async fn transmit<'a, T: Into + Copy>(&mut self, data: &'a [T]) -> Result<(), Error> + async fn transmit<'a>(&mut self, data: &'a [u32]) -> Result<(), Error> where Self: Sized, { @@ -1154,7 +1149,7 @@ pub trait TxChannelAsync: TxChannelInternal { Self::clear_interrupts(); Self::listen_interrupt(Event::End); Self::listen_interrupt(Event::Error); - Self::send_raw(data, false, 0); + Self::send_raw(data, false, 0)?; RmtTxFuture::new(self).await; @@ -1402,9 +1397,17 @@ where fn is_loopcount_interrupt_set() -> bool; - fn send_raw + Copy>(data: &[T], continuous: bool, repeat: u16) -> usize { + fn send_raw(data: &[u32], continuous: bool, repeat: u16) -> Result { Self::clear_interrupts(); + if let Some(last) = data.last() { + if !continuous && last.length2() != 0 && last.length1() != 0 { + return Err(Error::EndMarkerMissing); + } + } else { + return Err(Error::InvalidArgument); + } + let ptr = (constants::RMT_RAM_START + Self::CHANNEL as usize * constants::RMT_CHANNEL_RAM_SIZE * 4) as *mut u32; @@ -1414,7 +1417,7 @@ where .enumerate() { unsafe { - ptr.add(idx).write_volatile((*entry).into()); + ptr.add(idx).write_volatile(*entry); } } @@ -1428,9 +1431,9 @@ where Self::update(); if data.len() >= constants::RMT_CHANNEL_RAM_SIZE { - constants::RMT_CHANNEL_RAM_SIZE + Ok(constants::RMT_CHANNEL_RAM_SIZE) } else { - data.len() + Ok(data.len()) } } diff --git a/examples/src/bin/embassy_rmt_rx.rs b/examples/src/bin/embassy_rmt_rx.rs index 08c5f8bc28d..81faf25cd11 100644 --- a/examples/src/bin/embassy_rmt_rx.rs +++ b/examples/src/bin/embassy_rmt_rx.rs @@ -73,46 +73,41 @@ async fn main(spawner: Spawner) { .spawn(signal_task(Output::new(peripherals.GPIO5, Level::Low))) .unwrap(); - let mut data = [PulseCode { - level1: true, - length1: 1, - level2: false, - length2: 1, - }; 48]; + let mut data: [u32; 48] = [PulseCode::empty(); 48]; loop { println!("receive"); channel.receive(&mut data).await.unwrap(); let mut total = 0usize; for entry in &data[..data.len()] { - if entry.length1 == 0 { + if entry.length1() == 0 { break; } - total += entry.length1 as usize; + total += entry.length1() as usize; - if entry.length2 == 0 { + if entry.length2() == 0 { break; } - total += entry.length2 as usize; + total += entry.length2() as usize; } for entry in &data[..data.len()] { - if entry.length1 == 0 { + if entry.length1() == 0 { break; } - let count = WIDTH / (total / entry.length1 as usize); - let c = if entry.level1 { '-' } else { '_' }; + let count = WIDTH / (total / entry.length1() as usize); + let c = if entry.level1() { '-' } else { '_' }; for _ in 0..count + 1 { print!("{}", c); } - if entry.length2 == 0 { + if entry.length2() == 0 { break; } - let count = WIDTH / (total / entry.length2 as usize); - let c = if entry.level2 { '-' } else { '_' }; + let count = WIDTH / (total / entry.length2() as usize); + let c = if entry.level2() { '-' } else { '_' }; for _ in 0..count + 1 { print!("{}", c); } diff --git a/examples/src/bin/embassy_rmt_tx.rs b/examples/src/bin/embassy_rmt_tx.rs index 5943c453700..20e14cfa574 100644 --- a/examples/src/bin/embassy_rmt_tx.rs +++ b/examples/src/bin/embassy_rmt_tx.rs @@ -50,20 +50,10 @@ async fn main(_spawner: Spawner) { ) .unwrap(); - let mut data = [PulseCode { - level1: true, - length1: 200, - level2: false, - length2: 50, - }; 20]; + let mut data = [PulseCode::new(true, 200, false, 50); 20]; - data[data.len() - 2] = PulseCode { - level1: true, - length1: 3000, - level2: false, - length2: 500, - }; - data[data.len() - 1] = PulseCode::default(); + data[data.len() - 2] = PulseCode::new(true, 3000, false, 500); + data[data.len() - 1] = PulseCode::empty(); loop { println!("transmit"); diff --git a/examples/src/bin/rmt_rx.rs b/examples/src/bin/rmt_rx.rs index c9dc626b1b4..39869dd3741 100644 --- a/examples/src/bin/rmt_rx.rs +++ b/examples/src/bin/rmt_rx.rs @@ -56,17 +56,11 @@ fn main() -> ! { let delay = Delay::new(); - let mut data = [PulseCode { - level1: true, - length1: 1, - level2: false, - length2: 1, - }; 48]; + let mut data: [u32; 48] = [PulseCode::empty(); 48]; loop { for x in data.iter_mut() { - x.length1 = 0; - x.length2 = 0; + x.reset() } let transaction = channel.receive(&mut data).unwrap(); @@ -84,34 +78,34 @@ fn main() -> ! { channel = channel_res; let mut total = 0usize; for entry in &data[..data.len()] { - if entry.length1 == 0 { + if entry.length1() == 0 { break; } - total += entry.length1 as usize; + total += entry.length1() as usize; - if entry.length2 == 0 { + if entry.length2() == 0 { break; } - total += entry.length2 as usize; + total += entry.length2() as usize; } for entry in &data[..data.len()] { - if entry.length1 == 0 { + if entry.length1() == 0 { break; } - let count = WIDTH / (total / entry.length1 as usize); - let c = if entry.level1 { '-' } else { '_' }; + let count = WIDTH / (total / entry.length1() as usize); + let c = if entry.level1() { '-' } else { '_' }; for _ in 0..count + 1 { print!("{}", c); } - if entry.length2 == 0 { + if entry.length2() == 0 { break; } - let count = WIDTH / (total / entry.length2 as usize); - let c = if entry.level2 { '-' } else { '_' }; + let count = WIDTH / (total / entry.length2() as usize); + let c = if entry.level2() { '-' } else { '_' }; for _ in 0..count + 1 { print!("{}", c); } diff --git a/examples/src/bin/rmt_tx.rs b/examples/src/bin/rmt_tx.rs index d87e8d87a85..7689ce47c4a 100644 --- a/examples/src/bin/rmt_tx.rs +++ b/examples/src/bin/rmt_tx.rs @@ -43,23 +43,12 @@ fn main() -> ! { let delay = Delay::new(); - let mut data = [PulseCode { - level1: true, - length1: 200, - level2: false, - length2: 50, - }; 20]; - - data[data.len() - 2] = PulseCode { - level1: true, - length1: 3000, - level2: false, - length2: 500, - }; - data[data.len() - 1] = PulseCode::default(); + let mut data = [PulseCode::new(true, 200, false, 50); 20]; + data[data.len() - 2] = PulseCode::new(true, 3000, false, 500); + data[data.len() - 1] = PulseCode::empty(); loop { - let transaction = channel.transmit(&data); + let transaction = channel.transmit(&data).unwrap(); channel = transaction.wait().unwrap(); delay.delay_millis(500); } diff --git a/hil-test/tests/rmt.rs b/hil-test/tests/rmt.rs index c7c9db6d28e..5f014745b0a 100644 --- a/hil-test/tests/rmt.rs +++ b/hil-test/tests/rmt.rs @@ -14,6 +14,8 @@ use hil_test as _; #[cfg(test)] #[embedded_test::tests] mod tests { + use esp_hal::rmt::Error; + use super::*; #[init] @@ -76,30 +78,15 @@ mod tests { } } - let mut tx_data = [PulseCode { - level1: true, - length1: 200, - level2: false, - length2: 50, - }; 20]; - - tx_data[tx_data.len() - 2] = PulseCode { - level1: true, - length1: 3000, - level2: false, - length2: 500, - }; - tx_data[tx_data.len() - 1] = PulseCode::default(); + let mut tx_data = [PulseCode::new(true, 200, false, 50); 20]; + + tx_data[tx_data.len() - 2] = PulseCode::new(true, 3000, false, 500); + tx_data[tx_data.len() - 1] = PulseCode::empty(); - let mut rcv_data = [PulseCode { - level1: false, - length1: 0, - level2: false, - length2: 0, - }; 20]; + let mut rcv_data: [u32; 20] = [PulseCode::empty(); 20]; let rx_transaction = rx_channel.receive(&mut rcv_data).unwrap(); - let tx_transaction = tx_channel.transmit(&tx_data); + let tx_transaction = tx_channel.transmit(&tx_data).unwrap(); rx_transaction.wait().unwrap(); tx_transaction.wait().unwrap(); @@ -108,4 +95,41 @@ mod tests { // they can't be equal assert_eq!(&tx_data[..18], &rcv_data[..18]); } + + #[test] + #[timeout(1)] + fn rmt_single_shot_fails_without_end_marker() { + let peripherals = esp_hal::init(esp_hal::Config::default()); + + let io = Io::new(peripherals.GPIO, peripherals.IO_MUX); + + cfg_if::cfg_if! { + if #[cfg(feature = "esp32h2")] { + let freq = 32.MHz(); + } else { + let freq = 80.MHz(); + } + }; + + let rmt = Rmt::new(peripherals.RMT, freq).unwrap(); + + let (_, tx) = hil_test::common_test_pins!(io); + + let tx_config = TxChannelConfig { + clk_divider: 255, + ..TxChannelConfig::default() + }; + + let tx_channel = { + use esp_hal::rmt::TxChannelCreator; + rmt.channel0.configure(tx, tx_config).unwrap() + }; + + let tx_data = [PulseCode::new(true, 200, false, 50); 20]; + + let tx_transaction = tx_channel.transmit(&tx_data); + + assert!(tx_transaction.is_err()); + assert!(matches!(tx_transaction, Err(Error::EndMarkerMissing))); + } }