Skip to content

Commit

Permalink
Fail RMT one-shot transactions if end-marker is missing (#2463)
Browse files Browse the repository at this point in the history
* Fail RMT one-shot transactions if end-marker is missing

* CHANGELOG.md

* Add test

* Fix

* Fix

* RMT: use u32, turn PulseCode into a convenience trait

* Clippy

* Adapt test
  • Loading branch information
bjoernQ authored Nov 13, 2024
1 parent 8cbc249 commit 7da4444
Show file tree
Hide file tree
Showing 8 changed files with 178 additions and 154 deletions.
1 change: 1 addition & 0 deletions esp-hal/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `slave::Spi` constructors no longer take pins (#2485)
- The `I2c` master driver has been moved from `esp_hal::i2c` to `esp_hal::i2c::master`. (#2476)
- `I2c` SCL timeout is now defined in bus clock cycles. (#2477)
- Trying to send a single-shot RMT transmission will result in an error now, `RMT` deals with `u32` now, `PulseCode` is a convenience trait now (#2463)

### Fixed

Expand Down
28 changes: 28 additions & 0 deletions esp-hal/MIGRATING-0.21.md
Original file line number Diff line number Diff line change
Expand Up @@ -368,3 +368,31 @@ If you were using an 16-bit bus, you don't need to change anything, `set_byte_or

If you were sharing the bus between an 8-bit and 16-bit device, you will have to call the corresponding method when
you switch between devices. Be sure to read the documentation of the new methods.


## `rmt::Channel::transmit` now returns `Result`, `PulseCode` is now `u32`

When trying to send a one-shot transmission will fail if it doesn't end with an end-marker.

```diff
- let mut data = [PulseCode {
- level1: true,
- length1: 200,
- level2: false,
- length2: 50,
- }; 20];
-
- data[data.len() - 2] = PulseCode {
- level1: true,
- length1: 3000,
- level2: false,
- length2: 500,
- };
- data[data.len() - 1] = PulseCode::default();
+ let mut data = [PulseCode::new(true, 200, false, 50); 20];
+ data[data.len() - 2] = PulseCode::new(true, 3000, false, 500);
+ data[data.len() - 1] = PulseCode::empty();

- let transaction = channel.transmit(&data);
+ let transaction = channel.transmit(&data).unwrap();
```
145 changes: 74 additions & 71 deletions esp-hal/src/rmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,61 +112,62 @@ pub enum Error {
InvalidArgument,
/// An error occurred during transmission
TransmissionError,
/// No transmission end marker found
EndMarkerMissing,
}

/// Convenience representation of a pulse code entry.
///
/// Allows for the assignment of two levels and their lengths
#[derive(Clone, Copy, Debug, Default, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct PulseCode {
/// Convenience trait to work with pulse codes.
pub trait PulseCode: crate::private::Sealed {
/// Create a new instance
fn new(level1: bool, length1: u16, level2: bool, length2: u16) -> Self;

/// Create a new empty instance
fn empty() -> Self;

/// Set all levels and lengths to 0
fn reset(&mut self);

/// Logical output level in the first pulse code interval
pub level1: bool,
fn level1(&self) -> bool;

/// Length of the first pulse code interval (in clock cycles)
pub length1: u16,
fn length1(&self) -> u16;

/// Logical output level in the second pulse code interval
pub level2: bool,
fn level2(&self) -> bool;

/// Length of the second pulse code interval (in clock cycles)
pub length2: u16,
fn length2(&self) -> u16;
}

impl From<u32> for PulseCode {
fn from(value: u32) -> Self {
Self {
level1: value & (1 << 15) != 0,
length1: (value & 0b111_1111_1111_1111) as u16,
level2: value & (1 << 31) != 0,
length2: ((value >> 16) & 0b111_1111_1111_1111) as u16,
}
impl PulseCode for u32 {
fn new(level1: bool, length1: u16, level2: bool, length2: u16) -> Self {
(((level1 as u32) << 15) | length1 as u32 & 0b111_1111_1111_1111)
| (((level2 as u32) << 15) | length2 as u32 & 0b111_1111_1111_1111) << 16
}
}

/// Convert a pulse code structure into a u32 value that can be written
/// into the data registers
impl From<PulseCode> for u32 {
#[inline(always)]
fn from(p: PulseCode) -> u32 {
// The length1 value resides in bits [14:0]
let mut entry: u32 = p.length1 as u32;

// If level1 is high, set bit 15, otherwise clear it
if p.level1 {
entry |= 1 << 15;
} else {
entry &= !(1 << 15);
}
fn empty() -> Self {
0
}

// If level2 is high, set bit 31, otherwise clear it
if p.level2 {
entry |= 1 << 31;
} else {
entry &= !(1 << 31);
}
fn reset(&mut self) {
*self = 0
}

// The length2 value resides in bits [30:16]
entry |= (p.length2 as u32) << 16;
fn level1(&self) -> bool {
self & (1 << 15) != 0
}

fn length1(&self) -> u16 {
(self & 0b111_1111_1111_1111) as u16
}

fn level2(&self) -> bool {
self & (1 << 31) != 0
}

entry
fn length2(&self) -> u16 {
((self >> 16) & 0b111_1111_1111_1111) as u16
}
}

Expand Down Expand Up @@ -423,16 +424,16 @@ where
}

/// An in-progress transaction for a single shot TX transaction.
pub struct SingleShotTxTransaction<'a, C, T: Into<u32> + Copy>
pub struct SingleShotTxTransaction<'a, C>
where
C: TxChannel,
{
channel: C,
index: usize,
data: &'a [T],
data: &'a [u32],
}

impl<C, T: Into<u32> + Copy> SingleShotTxTransaction<'_, C, T>
impl<C> SingleShotTxTransaction<'_, C>
where
C: TxChannel,
{
Expand Down Expand Up @@ -466,7 +467,7 @@ where
.enumerate()
{
unsafe {
ptr.add(idx).write_volatile((*entry).into());
ptr.add(idx).write_volatile(*entry);
}
}

Expand Down Expand Up @@ -982,26 +983,23 @@ pub trait TxChannel: TxChannelInternal<Blocking> {
/// This returns a [`SingleShotTxTransaction`] which can be used to wait for
/// the transaction to complete and get back the channel for further
/// use.
fn transmit<T: Into<u32> + Copy>(self, data: &[T]) -> SingleShotTxTransaction<'_, Self, T>
fn transmit(self, data: &[u32]) -> Result<SingleShotTxTransaction<'_, Self>, Error>
where
Self: Sized,
{
let index = Self::send_raw(data, false, 0);
SingleShotTxTransaction {
let index = Self::send_raw(data, false, 0)?;
Ok(SingleShotTxTransaction {
channel: self,
index,
data,
}
})
}

/// Start transmitting the given pulse code continuously.
/// This returns a [`ContinuousTxTransaction`] which can be used to stop the
/// ongoing transmission and get back the channel for further use.
/// The length of sequence cannot exceed the size of the allocated RMT RAM.
fn transmit_continuously<T: Into<u32> + Copy>(
self,
data: &[T],
) -> Result<ContinuousTxTransaction<Self>, Error>
fn transmit_continuously(self, data: &[u32]) -> Result<ContinuousTxTransaction<Self>, Error>
where
Self: Sized,
{
Expand All @@ -1011,10 +1009,10 @@ pub trait TxChannel: TxChannelInternal<Blocking> {
/// Like [`Self::transmit_continuously`] but also sets a loop count.
/// [`ContinuousTxTransaction`] can be used to check if the loop count is
/// reached.
fn transmit_continuously_with_loopcount<T: Into<u32> + Copy>(
fn transmit_continuously_with_loopcount(
self,
loopcount: u16,
data: &[T],
data: &[u32],
) -> Result<ContinuousTxTransaction<Self>, Error>
where
Self: Sized,
Expand All @@ -1023,21 +1021,21 @@ pub trait TxChannel: TxChannelInternal<Blocking> {
return Err(Error::Overflow);
}

let _index = Self::send_raw(data, true, loopcount);
let _index = Self::send_raw(data, true, loopcount)?;
Ok(ContinuousTxTransaction { channel: self })
}
}

/// RX transaction instance
pub struct RxTransaction<'a, C, T: From<u32> + Copy>
pub struct RxTransaction<'a, C>
where
C: RxChannel,
{
channel: C,
data: &'a mut [T],
data: &'a mut [u32],
}

impl<C, T: From<u32> + Copy> RxTransaction<'_, C, T>
impl<C> RxTransaction<'_, C>
where
C: RxChannel,
{
Expand All @@ -1062,7 +1060,7 @@ where
as *mut u32;
let len = self.data.len();
for (idx, entry) in self.data.iter_mut().take(len).enumerate() {
*entry = unsafe { ptr.add(idx).read_volatile().into() };
*entry = unsafe { ptr.add(idx).read_volatile() };
}

Ok(self.channel)
Expand All @@ -1075,10 +1073,7 @@ pub trait RxChannel: RxChannelInternal<Blocking> {
/// This returns a [RxTransaction] which can be used to wait for receive to
/// complete and get back the channel for further use.
/// The length of the received data cannot exceed the allocated RMT RAM.
fn receive<T: From<u32> + Copy>(
self,
data: &mut [T],
) -> Result<RxTransaction<'_, Self, T>, Error>
fn receive(self, data: &mut [u32]) -> Result<RxTransaction<'_, Self>, Error>
where
Self: Sized,
{
Expand Down Expand Up @@ -1143,7 +1138,7 @@ pub trait TxChannelAsync: TxChannelInternal<Async> {
/// Start transmitting the given pulse code sequence.
/// The length of sequence cannot exceed the size of the allocated RMT
/// RAM.
async fn transmit<'a, T: Into<u32> + Copy>(&mut self, data: &'a [T]) -> Result<(), Error>
async fn transmit<'a>(&mut self, data: &'a [u32]) -> Result<(), Error>
where
Self: Sized,
{
Expand All @@ -1154,7 +1149,7 @@ pub trait TxChannelAsync: TxChannelInternal<Async> {
Self::clear_interrupts();
Self::listen_interrupt(Event::End);
Self::listen_interrupt(Event::Error);
Self::send_raw(data, false, 0);
Self::send_raw(data, false, 0)?;

RmtTxFuture::new(self).await;

Expand Down Expand Up @@ -1402,9 +1397,17 @@ where

fn is_loopcount_interrupt_set() -> bool;

fn send_raw<T: Into<u32> + Copy>(data: &[T], continuous: bool, repeat: u16) -> usize {
fn send_raw(data: &[u32], continuous: bool, repeat: u16) -> Result<usize, Error> {
Self::clear_interrupts();

if let Some(last) = data.last() {
if !continuous && last.length2() != 0 && last.length1() != 0 {
return Err(Error::EndMarkerMissing);
}
} else {
return Err(Error::InvalidArgument);
}

let ptr = (constants::RMT_RAM_START
+ Self::CHANNEL as usize * constants::RMT_CHANNEL_RAM_SIZE * 4)
as *mut u32;
Expand All @@ -1414,7 +1417,7 @@ where
.enumerate()
{
unsafe {
ptr.add(idx).write_volatile((*entry).into());
ptr.add(idx).write_volatile(*entry);
}
}

Expand All @@ -1428,9 +1431,9 @@ where
Self::update();

if data.len() >= constants::RMT_CHANNEL_RAM_SIZE {
constants::RMT_CHANNEL_RAM_SIZE
Ok(constants::RMT_CHANNEL_RAM_SIZE)
} else {
data.len()
Ok(data.len())
}
}

Expand Down
27 changes: 11 additions & 16 deletions examples/src/bin/embassy_rmt_rx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,46 +73,41 @@ async fn main(spawner: Spawner) {
.spawn(signal_task(Output::new(peripherals.GPIO5, Level::Low)))
.unwrap();

let mut data = [PulseCode {
level1: true,
length1: 1,
level2: false,
length2: 1,
}; 48];
let mut data: [u32; 48] = [PulseCode::empty(); 48];

loop {
println!("receive");
channel.receive(&mut data).await.unwrap();
let mut total = 0usize;
for entry in &data[..data.len()] {
if entry.length1 == 0 {
if entry.length1() == 0 {
break;
}
total += entry.length1 as usize;
total += entry.length1() as usize;

if entry.length2 == 0 {
if entry.length2() == 0 {
break;
}
total += entry.length2 as usize;
total += entry.length2() as usize;
}

for entry in &data[..data.len()] {
if entry.length1 == 0 {
if entry.length1() == 0 {
break;
}

let count = WIDTH / (total / entry.length1 as usize);
let c = if entry.level1 { '-' } else { '_' };
let count = WIDTH / (total / entry.length1() as usize);
let c = if entry.level1() { '-' } else { '_' };
for _ in 0..count + 1 {
print!("{}", c);
}

if entry.length2 == 0 {
if entry.length2() == 0 {
break;
}

let count = WIDTH / (total / entry.length2 as usize);
let c = if entry.level2 { '-' } else { '_' };
let count = WIDTH / (total / entry.length2() as usize);
let c = if entry.level2() { '-' } else { '_' };
for _ in 0..count + 1 {
print!("{}", c);
}
Expand Down
Loading

0 comments on commit 7da4444

Please sign in to comment.