diff --git a/examples/apps/src/ble_bas_central.rs b/examples/apps/src/ble_bas_central.rs index ba7d91f6..b04f496b 100644 --- a/examples/apps/src/ble_bas_central.rs +++ b/examples/apps/src/ble_bas_central.rs @@ -1,10 +1,12 @@ -use embassy_futures::join::join; +use embassy_futures::join::{join, join3}; +use embassy_sync::blocking_mutex::raw::NoopRawMutex; use embassy_time::{Duration, Timer}; use static_cell::StaticCell; use trouble_host::attribute::Uuid; use trouble_host::connection::ConnectConfig; +use trouble_host::packet_pool::PacketPool; use trouble_host::scan::ScanConfig; -use trouble_host::{Address, BleHost, BleHostResources, Controller, PacketQos}; +use trouble_host::{AddrKind, Address, BdAddr, BleHost, BleHostResources, Controller, PacketQos}; /// Size of L2CAP packets const L2CAP_MTU: usize = 128; @@ -39,27 +41,53 @@ where info!("Scanning for peripheral..."); let _ = join(ble.run(), async { + static PACKET_POOL: StaticCell> = StaticCell::new(); + let packet_pool = PACKET_POOL.init(PacketPool::new(PacketQos::None)); + + info!("Connecting"); + let conn = ble.connect(&config).await.unwrap(); info!("Connected, creating gatt client"); - let mut client = ble.gatt_client::<10, 128>(&conn).await.unwrap(); + let mut client = ble.gatt_client::<10, 64, 16, 24>(&conn, packet_pool).await.unwrap(); + + let _ = join(client.task(), async { + info!("Looking for battery service"); + let services = client.services_by_uuid(&Uuid::new_short(0x180f)).await.unwrap(); + let service = services.first().unwrap().clone(); - info!("Looking for battery service"); - let services = client.services_by_uuid(&Uuid::new_short(0x180f)).await.unwrap(); - let service = services.first().unwrap().clone(); + info!("Looking for value handle"); + let c = client + .characteristic_by_uuid(&service, &Uuid::new_short(0x2a19)) + .await + .unwrap(); - info!("Looking for value handle"); - let c = client - .characteristic_by_uuid(&service, &Uuid::new_short(0x2a19)) - .await - .unwrap(); + info!("Subscribing notifications"); + let mut listener = client.subscribe(&c, false).await.unwrap(); - loop { - let mut data = [0; 1]; - client.read_characteristic(&c, &mut data[..]).await.unwrap(); - info!("Read value: {}", data[0]); - Timer::after(Duration::from_secs(10)).await; - } + let _ = join( + async { + loop { + let mut data = [0; 1]; + client.read_characteristic(&c, &mut data[..]).await.unwrap(); + info!("Read value: {}", data[0]); + Timer::after(Duration::from_secs(10)).await; + } + }, + async { + loop { + let (len, data) = listener.next().await; + defmt::info!( + "Got notification: {:x} (val: {})", + &data.as_ref()[..len as usize], + data.as_ref()[0] + ); + } + }, + ) + .await; + }) + .await; }) .await; } diff --git a/examples/rp-pico-w/Cargo.toml b/examples/rp-pico-w/Cargo.toml index 8eca8f9f..2592edd5 100644 --- a/examples/rp-pico-w/Cargo.toml +++ b/examples/rp-pico-w/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" resolver = "2" [dependencies] -embassy-executor = { version = "0.6", default-features = false, features = ["task-arena-size-65536", "arch-cortex-m", "executor-thread", "defmt", "integrated-timers", "executor-interrupt"] } +embassy-executor = { version = "0.6", default-features = false, features = ["task-arena-size-98304", "arch-cortex-m", "executor-thread", "defmt", "integrated-timers", "executor-interrupt"] } embassy-time = { version = "0.3.0", default-features = false, features = ["defmt", "defmt-timestamp-uptime"] } embassy-rp = { version = "0.2.0", features = ["defmt", "unstable-pac", "time-driver", "critical-section-impl", "rp2040"] } embassy-futures = "0.1.1" diff --git a/examples/rp-pico-w/src/bin/ble_bas_central.rs b/examples/rp-pico-w/src/bin/ble_bas_central.rs new file mode 100644 index 00000000..8c0da46f --- /dev/null +++ b/examples/rp-pico-w/src/bin/ble_bas_central.rs @@ -0,0 +1,56 @@ +#![no_std] +#![no_main] + +use bt_hci::controller::ExternalController; +use cyw43_pio::PioSpi; +use defmt::*; +use embassy_executor::Spawner; +use embassy_rp::bind_interrupts; +use embassy_rp::gpio::{Level, Output}; +use embassy_rp::peripherals::{DMA_CH0, PIO0}; +use embassy_rp::pio::{InterruptHandler, Pio}; +use static_cell::StaticCell; +use trouble_example_apps::ble_bas_central; +use {defmt_rtt as _, embassy_time as _, panic_probe as _}; + +bind_interrupts!(struct Irqs { + PIO0_IRQ_0 => InterruptHandler; +}); + +#[embassy_executor::task] +async fn cyw43_task(runner: cyw43::Runner<'static, Output<'static>, PioSpi<'static, PIO0, 0, DMA_CH0>>) -> ! { + runner.run().await +} + +#[embassy_executor::main] +async fn main(spawner: Spawner) { + let p = embassy_rp::init(Default::default()); + + // + // IMPORTANT + // + // Download and make sure these files from https://github.com/embassy-rs/embassy/tree/main/cyw43-firmware + // are available in the below path. + // + // IMPORTANT + // + + let fw = include_bytes!("../../cyw43-firmware/43439A0.bin"); + let clm = include_bytes!("../../cyw43-firmware/43439A0_clm.bin"); + let btfw = include_bytes!("../../cyw43-firmware/43439A0_btfw.bin"); + + let pwr = Output::new(p.PIN_23, Level::Low); + let cs = Output::new(p.PIN_25, Level::High); + let mut pio = Pio::new(p.PIO0, Irqs); + let spi = PioSpi::new(&mut pio.common, pio.sm0, pio.irq0, cs, p.PIN_24, p.PIN_29, p.DMA_CH0); + + static STATE: StaticCell = StaticCell::new(); + let state = STATE.init(cyw43::State::new()); + let (_net_device, bt_device, mut control, runner) = cyw43::new_with_bluetooth(state, pwr, spi, fw, btfw).await; + unwrap!(spawner.spawn(cyw43_task(runner))); + control.init(clm).await; + + let controller: ExternalController<_, 10> = ExternalController::new(bt_device); + + ble_bas_central::run(controller).await; +} diff --git a/host/src/attribute.rs b/host/src/attribute.rs index 33f95863..44d41c44 100644 --- a/host/src/attribute.rs +++ b/host/src/attribute.rs @@ -5,7 +5,7 @@ use embassy_sync::blocking_mutex::raw::RawMutex; use embassy_sync::blocking_mutex::Mutex; use crate::att::AttErrorCode; -use crate::cursor::WriteCursor; +use crate::cursor::{ReadCursor, WriteCursor}; pub use crate::types::uuid::Uuid; use crate::Error; @@ -211,6 +211,15 @@ impl<'d> AttributeData<'d> { _ => Err(AttErrorCode::WriteNotPermitted), } } + + pub fn decode_declaration(data: &[u8]) -> Result { + let mut r = ReadCursor::new(data); + Ok(Self::Declaration { + props: CharacteristicProps(r.read()?), + handle: r.read()?, + uuid: Uuid::from_slice(r.remaining()), + }) + } } impl<'a> fmt::Debug for Attribute<'a> { @@ -564,7 +573,7 @@ impl From<[CharacteristicProp; T]> for CharacteristicProps { } impl CharacteristicProps { - fn any(&self, props: &[CharacteristicProp]) -> bool { + pub fn any(&self, props: &[CharacteristicProp]) -> bool { for p in props { if (*p as u8) & self.0 != 0 { return true; @@ -579,3 +588,32 @@ pub struct AttributeValue<'d, M: RawMutex> { } impl<'d, M: RawMutex> AttributeValue<'d, M> {} + +#[derive(Clone, Copy)] +pub enum CCCDFlag { + Notify = 0x1, + Indicate = 0x2, +} + +pub struct CCCD(pub(crate) u16); + +impl From<[CCCDFlag; T]> for CCCD { + fn from(props: [CCCDFlag; T]) -> Self { + let mut val: u16 = 0; + for prop in props { + val |= prop as u16; + } + CCCD(val) + } +} + +impl CCCD { + pub fn any(&self, props: &[CCCDFlag]) -> bool { + for p in props { + if (*p as u16) & self.0 != 0 { + return true; + } + } + false + } +} diff --git a/host/src/gatt.rs b/host/src/gatt.rs index 8e2c4fc8..381c262c 100644 --- a/host/src/gatt.rs +++ b/host/src/gatt.rs @@ -1,16 +1,24 @@ +use core::cell::{Ref, RefCell}; +use core::future::Future; +use core::marker::PhantomData; + use bt_hci::controller::Controller; use bt_hci::param::ConnHandle; -use embassy_sync::blocking_mutex::raw::RawMutex; -use embassy_sync::channel::{DynamicReceiver, DynamicSender}; +use embassy_sync::blocking_mutex::raw::{NoopRawMutex, RawMutex}; +use embassy_sync::channel::{Channel, DynamicReceiver, DynamicSender}; use heapless::Vec; use crate::att::{self, AttReq, AttRsp, ATT_HANDLE_VALUE_NTF}; -use crate::attribute::{Characteristic, Uuid, CHARACTERISTIC_UUID16, PRIMARY_SERVICE_UUID16}; +use crate::attribute::{ + AttributeData, Characteristic, CharacteristicProp, Uuid, CCCD, CHARACTERISTIC_CCCD_UUID16, CHARACTERISTIC_UUID16, + PRIMARY_SERVICE_UUID16, +}; use crate::attribute_server::AttributeServer; use crate::connection::Connection; use crate::connection_manager::DynamicConnectionManager; use crate::cursor::{ReadCursor, WriteCursor}; use crate::host::BleHost; +use crate::packet_pool::{GlobalPacketPool, Packet, PacketPool, ATT_ID}; use crate::pdu::Pdu; use crate::types::l2cap::L2capHeader; use crate::{BleHostError, Error}; @@ -138,11 +146,49 @@ pub enum GattEvent<'reference> { }, } -pub struct GattClient<'reference, 'resources, T: Controller, const MAX: usize, const L2CAP_MTU: usize = 27> { - pub(crate) services: Vec, +pub struct NotificationListener<'lst> { + pub(crate) listener: DynamicReceiver<'lst, (u16, Packet)>, +} + +impl<'lst> NotificationListener<'lst> { + #[allow(clippy::should_implement_trait)] + /// Get the next (len: u16, Packet) tuple from the rx queue + pub fn next(&mut self) -> impl Future + '_ { + self.listener.receive() + } +} + +pub struct NotificationManager< + 'mgr, + E, + C: Client, + const MAX_NOTIF: usize, + const NOTIF_QSIZE: usize, + const ATT_MTU: usize, +> { + pub(crate) client: &'mgr C, + pub(crate) rx: DynamicReceiver<'mgr, (ConnHandle, Pdu)>, + _e: PhantomData, +} + +pub struct GattClient< + 'reference, + 'resources, + T: Controller, + const MAX_SERVICES: usize, + const MAX_NOTIF: usize, + const NOTIF_QSIZE: usize, + const L2CAP_MTU: usize = 27, +> { + pub(crate) known_services: RefCell>, pub(crate) rx: DynamicReceiver<'reference, (ConnHandle, Pdu)>, pub(crate) ble: &'reference BleHost<'resources, T>, pub(crate) connection: Connection<'reference>, + pub(crate) request_channel: Channel, + + pub(crate) notification_pool: &'static PacketPool, + pub(crate) notification_map: RefCell<[Option; MAX_NOTIF]>, + pub(crate) notification_channels: [Channel; MAX_NOTIF], } #[cfg_attr(feature = "defmt", derive(defmt::Format))] @@ -153,10 +199,21 @@ pub struct ServiceHandle { uuid: Uuid, } -impl<'reference, 'resources, T: Controller, const MAX: usize, const L2CAP_MTU: usize> - GattClient<'reference, 'resources, T, MAX, L2CAP_MTU> +pub trait Client { + fn request(&self, req: AttReq<'_>) -> impl Future>>; +} + +impl< + 'reference, + 'resources, + T: Controller, + const MAX_SERVICES: usize, + const MAX_NOTIF: usize, + const NOTIF_QSIZE: usize, + const L2CAP_MTU: usize, + > Client for GattClient<'reference, 'resources, T, MAX_SERVICES, MAX_NOTIF, NOTIF_QSIZE, L2CAP_MTU> { - async fn request(&mut self, req: AttReq<'_>) -> Result> { + async fn request(&self, req: AttReq<'_>) -> Result> { let header = L2capHeader { channel: crate::types::l2cap::L2CAP_CID_ATT, length: req.size() as u16, @@ -170,14 +227,30 @@ impl<'reference, 'resources, T: Controller, const MAX: usize, const L2CAP_MTU: u let mut grant = self.ble.acl(self.connection.handle(), 1).await?; grant.send(w.finish()).await?; - let (h, pdu) = self.rx.receive().await; + let (h, pdu) = self.request_channel.receive().await; + assert_eq!(h, self.connection.handle()); Ok(pdu) } +} +impl< + 'reference, + 'resources, + T: Controller, + const MAX_SERVICES: usize, + const MAX_NOTIF: usize, + const NOTIF_QSIZE: usize, + const L2CAP_MTU: usize, + > GattClient<'reference, 'resources, T, MAX_SERVICES, MAX_NOTIF, NOTIF_QSIZE, L2CAP_MTU> +{ /// Discover primary services associated with a UUID. - pub async fn services_by_uuid(&mut self, uuid: &Uuid) -> Result<&[ServiceHandle], BleHostError> { + pub async fn services_by_uuid( + &self, + uuid: &Uuid, + ) -> Result, BleHostError> { let mut start: u16 = 0x0001; + let mut result = Vec::new(); loop { let data = att::AttReq::FindByTypeValue { @@ -200,13 +273,16 @@ impl<'reference, 'resources, T: Controller, const MAX: usize, const L2CAP_MTU: u while let Some(res) = it.next() { let (handle, e) = res?; end = e; - self.services - .push(ServiceHandle { - start: handle, - end, - uuid: uuid.clone(), - }) - .unwrap(); + let svc = ServiceHandle { + start: handle, + end, + uuid: uuid.clone(), + }; + result.push(svc.clone()).map_err(|_| Error::InsufficientSpace)?; + self.known_services + .borrow_mut() + .push(svc) + .map_err(|_| Error::InsufficientSpace)?; } if end == 0xFFFF { break; @@ -219,12 +295,12 @@ impl<'reference, 'resources, T: Controller, const MAX: usize, const L2CAP_MTU: u } } - Ok(&self.services[..]) + Ok(result) } /// Discover characteristics in a given service using a UUID. pub async fn characteristic_by_uuid( - &mut self, + &self, service: &ServiceHandle, uuid: &Uuid, ) -> Result> { @@ -243,22 +319,31 @@ impl<'reference, 'resources, T: Controller, const MAX: usize, const L2CAP_MTU: u if item.len() < 5 { return Err(Error::InvalidValue.into()); } - let mut r = ReadCursor::new(item); - let _props: u8 = r.read()?; - let value_handle: u16 = r.read()?; - let value_uuid: Uuid = Uuid::from_slice(r.remaining()); - - if uuid == &value_uuid { - return Ok(Characteristic { - handle: value_handle, - cccd_handle: None, - }); - } + if let AttributeData::Declaration { + props, + handle, + uuid: decl_uuid, + } = AttributeData::decode_declaration(item)? + { + if *uuid == decl_uuid { + // "notify" and "indicate" characteristic properties + let cccd_handle = + if props.any(&[CharacteristicProp::Indicate, CharacteristicProp::Notify]) { + Some(self.get_characteristic_cccd(handle).await?.0) + } else { + None + }; + + return Ok(Characteristic { handle, cccd_handle }); + } - if handle == 0xFFFF { - return Err(Error::NotFound.into()); + if handle == 0xFFFF { + return Err(Error::NotFound.into()); + } + start = handle + 1; + } else { + return Err(Error::InvalidValue.into()); } - start = handle + 1; } } AttRsp::Error { request, handle, code } => return Err(Error::Att(code).into()), @@ -269,11 +354,36 @@ impl<'reference, 'resources, T: Controller, const MAX: usize, const L2CAP_MTU: u } } + async fn get_characteristic_cccd(&self, char_handle: u16) -> Result<(u16, CCCD), BleHostError> { + let data = att::AttReq::ReadByType { + start: char_handle, + end: char_handle + 1, + attribute_type: CHARACTERISTIC_CCCD_UUID16, + }; + + let pdu = self.request(data).await?; + + match AttRsp::decode(pdu.as_ref())? { + AttRsp::ReadByType { mut it } => { + if let Some(Ok((handle, item))) = it.next() { + Ok(( + handle, + CCCD(u16::from_le_bytes(item.try_into().map_err(|_| Error::OutOfMemory)?)), + )) + } else { + Err(Error::NotFound.into()) + } + } + AttRsp::Error { request, handle, code } => Err(Error::Att(code).into()), + _ => Err(Error::InvalidValue.into()), + } + } + /// Read a characteristic described by a handle. /// /// The number of bytes copied into the provided buffer is returned. pub async fn read_characteristic( - &mut self, + &self, characteristic: &Characteristic, dest: &mut [u8], ) -> Result> { @@ -344,4 +454,111 @@ impl<'reference, 'resources, T: Controller, const MAX: usize, const L2CAP_MTU: u _ => Err(Error::InvalidValue.into()), } } + + /// Subscribe to indication/notification of a given Characteristic + /// + /// A listener is returned, which has a `next()` method + pub async fn subscribe( + &self, + characteristic: &Characteristic, + indication: bool, + ) -> Result> { + let properties = u16::to_le_bytes(if indication { 0x02 } else { 0x01 }); + + let data = att::AttReq::Write { + handle: characteristic.cccd_handle.ok_or(Error::NotSupported)?, + data: &properties, + }; + + // set the CCCD + let pdu = self.request(data).await?; + + match AttRsp::decode(pdu.as_ref())? { + AttRsp::Write => { + // look for a free slot in the n_channel -> handle array + for (n, item) in self.notification_map.borrow_mut().iter_mut().enumerate() { + if item.is_none() { + item.replace(characteristic.handle); + return Ok(NotificationListener { + listener: self.notification_channels[n].dyn_receiver(), + }); + } + } + // otherwise, there's no space left in the array + Err(Error::InsufficientSpace.into()) + } + AttRsp::Error { request, handle, code } => Err(Error::Att(code).into()), + _ => Err(Error::InvalidValue.into()), + } + } + + /// Unsubscribe from a given Characteristic + pub async fn unsubscribe(&self, characteristic: &Characteristic) -> Result<(), BleHostError> { + let mut notifications = self.notification_map.borrow_mut(); + let (item, n) = notifications + .iter_mut() + .enumerate() + .find_map(|(n, item)| { + if let Some(h) = item { + if *h == characteristic.handle { + Some((item, n)) + } else { + None + } + } else { + None + } + }) + .ok_or(Error::NotFound)?; + + // Free up the slot in the n_channel -> handle map + item.take(); + // Clear any data queued up in the channel + self.notification_channels[n].clear(); + Ok(()) + } + + pub async fn handle_notification_packet(&'reference self, data: &[u8]) -> Result<(), BleHostError> { + let mut r = ReadCursor::new(data); + let value_handle: u16 = r.read()?; + let value_attr = r.remaining(); + + // let's find the corresponding `n` first, to avoid retaining the borrow_mut() across an await point + let found_n = self + .notification_map + .borrow_mut() + .iter() + .enumerate() + .find_map(|(n, item)| { + if let Some(handle) = item { + if *handle == value_handle { + return Some(n); + } + } + None + }); + + if let Some(n) = found_n { + let mut packet = self.notification_pool.alloc(ATT_ID).ok_or(Error::InsufficientSpace)?; + let len = value_attr.len(); + packet.as_mut()[..len].copy_from_slice(value_attr); + self.notification_channels[n].send((len as u16, packet)).await; + } + Ok(()) + } + + /// Task which handles GATT rx data (needed for notifications to work) + pub async fn task(&self) -> Result<(), BleHostError> { + loop { + let (handle, pdu) = self.rx.receive().await; + let data = pdu.as_ref(); + + // handle notifications + if data[0] == ATT_HANDLE_VALUE_NTF { + self.handle_notification_packet(&data[1..]).await?; + } else { + self.request_channel.send((handle, pdu)).await; + } + } + } } diff --git a/host/src/host.rs b/host/src/host.rs index 43eca225..9912332e 100644 --- a/host/src/host.rs +++ b/host/src/host.rs @@ -759,10 +759,18 @@ where /// Creates a GATT client capable of processing the GATT protocol using the provided table of attributes. #[cfg(feature = "gatt")] - pub async fn gatt_client<'reference, const MAX: usize, const L2CAP_MTU: usize>( + pub async fn gatt_client< + 'reference, + const MAX_SERVICES: usize, + const MAX_NOTIF: usize, + const NOTIF_QSIZE: usize, + const L2CAP_MTU: usize, + >( &'reference self, connection: &Connection<'reference>, - ) -> Result, BleHostError> { + notification_pool: &'static PacketPool, + ) -> Result, BleHostError> + { let l2cap = L2capHeader { channel: 4, length: 3 }; let mut buf = [0; 7]; let mut w = WriteCursor::new(&mut buf); @@ -776,10 +784,16 @@ where grant.send(w.finish()).await?; Ok(GattClient { - services: heapless::Vec::new(), + known_services: RefCell::new(heapless::Vec::new()), rx: self.att_inbound.receiver().into(), ble: self, connection: connection.clone(), + + request_channel: Channel::new(), + + notification_pool, + notification_map: RefCell::new([const { None }; MAX_NOTIF]), + notification_channels: [const { Channel::new() }; MAX_NOTIF], }) } diff --git a/host/src/lib.rs b/host/src/lib.rs index c2a6b398..ad7fffd9 100644 --- a/host/src/lib.rs +++ b/host/src/lib.rs @@ -23,7 +23,7 @@ mod command; pub mod config; mod connection_manager; mod cursor; -mod packet_pool; +pub mod packet_pool; mod pdu; pub mod types; diff --git a/host/tests/gatt.rs b/host/tests/gatt.rs index 933601d6..9f903ed0 100644 --- a/host/tests/gatt.rs +++ b/host/tests/gatt.rs @@ -160,7 +160,7 @@ async fn gatt_client_server() { tokio::time::sleep(Duration::from_secs(5)).await; println!("[central] creating gatt client"); - let mut client = adapter.gatt_client::<10, 128>(&conn).await.unwrap(); + let mut client = adapter.gatt_client::<10, 128, 10, 24>(&conn).await.unwrap(); println!("[central] discovering services"); let services = client.services_by_uuid(&SERVICE_UUID).await.unwrap();