Skip to content

Commit

Permalink
Add possibility to listen to notifications
Browse files Browse the repository at this point in the history
  • Loading branch information
pferreir committed Sep 1, 2024
1 parent ad1c71e commit 1e512ab
Show file tree
Hide file tree
Showing 5 changed files with 220 additions and 24 deletions.
2 changes: 1 addition & 1 deletion examples/apps/src/ble_bas_central.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ where
let conn = ble.connect(&config).await.unwrap();
info!("Connected, creating gatt client");

let mut client = ble.gatt_client::<10, 128>(&conn).await.unwrap();
let mut client = ble.gatt_client::<10, 128, 16, 24>(&conn).await.unwrap();

info!("Looking for battery service");
let services = client.services_by_uuid(&Uuid::new_short(0x180f)).await.unwrap();
Expand Down
42 changes: 40 additions & 2 deletions host/src/attribute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use core::fmt;
use embassy_sync::blocking_mutex::raw::RawMutex;
use embassy_sync::blocking_mutex::Mutex;

use crate::att::AttErrorCode;
use crate::{att::AttErrorCode, cursor::ReadCursor};
use crate::cursor::WriteCursor;
pub use crate::types::uuid::Uuid;
use crate::Error;
Expand Down Expand Up @@ -211,6 +211,15 @@ impl<'d> AttributeData<'d> {
_ => Err(AttErrorCode::WriteNotPermitted),
}
}

pub fn decode_declaration(data: &[u8]) -> Result<Self, Error> {
let mut r = ReadCursor::new(data);
Ok(Self::Declaration {
props: CharacteristicProps(r.read()?),
handle: r.read()?,
uuid: Uuid::from_slice(r.remaining()),
})
}
}

impl<'a> fmt::Debug for Attribute<'a> {
Expand Down Expand Up @@ -564,7 +573,7 @@ impl<const T: usize> From<[CharacteristicProp; T]> for CharacteristicProps {
}

impl CharacteristicProps {
fn any(&self, props: &[CharacteristicProp]) -> bool {
pub fn any(&self, props: &[CharacteristicProp]) -> bool {
for p in props {
if (*p as u8) & self.0 != 0 {
return true;
Expand All @@ -579,3 +588,32 @@ pub struct AttributeValue<'d, M: RawMutex> {
}

impl<'d, M: RawMutex> AttributeValue<'d, M> {}

#[derive(Clone, Copy)]
pub enum CCCDFlag {
Notify = 0x1,
Indicate = 0x2
}

pub struct CCCD(pub(crate) u16);

impl<const T: usize> From<[CCCDFlag; T]> for CCCD {
fn from(props: [CCCDFlag; T]) -> Self {
let mut val: u16 = 0;
for prop in props {
val |= prop as u16;
}
CCCD(val)
}
}

impl CCCD {
pub fn any(&self, props: &[CCCDFlag]) -> bool {
for p in props {
if (*p as u16) & self.0 != 0 {
return true;
}
}
false
}
}
192 changes: 174 additions & 18 deletions host/src/gatt.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
use core::cell::RefCell;
use core::future::Future;
use bt_hci::controller::Controller;
use bt_hci::param::ConnHandle;
use embassy_sync::blocking_mutex::raw::RawMutex;
use embassy_sync::channel::{DynamicReceiver, DynamicSender};
use embassy_sync::blocking_mutex::raw::{NoopRawMutex, RawMutex};
use embassy_sync::channel::{Channel, DynamicReceiver, DynamicSender};
use heapless::Vec;

use crate::att::{self, AttReq, AttRsp, ATT_HANDLE_VALUE_NTF};
use crate::attribute::{Characteristic, Uuid, CHARACTERISTIC_UUID16, PRIMARY_SERVICE_UUID16};
use crate::attribute::{
AttributeData, Characteristic, CharacteristicProp, Uuid, CCCD, CHARACTERISTIC_CCCD_UUID16, CHARACTERISTIC_UUID16, PRIMARY_SERVICE_UUID16
};
use crate::attribute_server::AttributeServer;
use crate::connection::Connection;
use crate::connection_manager::DynamicConnectionManager;
Expand Down Expand Up @@ -138,11 +142,35 @@ pub enum GattEvent<'reference> {
},
}

pub struct GattClient<'reference, 'resources, T: Controller, const MAX: usize, const L2CAP_MTU: usize = 27> {
pub(crate) services: Vec<ServiceHandle, MAX>,
pub struct NotificationListener<'t, const MAX_NOTIF: usize, const NOTIF_QSIZE: usize, const ATT_MTU: usize> {
pub(crate) listener: DynamicReceiver<'t, Vec<u8, ATT_MTU>>,
}

impl<'t, const MAX_NOTIF: usize, const NOTIF_QSIZE: usize, const ATT_MTU: usize>
NotificationListener<'t, MAX_NOTIF, NOTIF_QSIZE, ATT_MTU>
{
#[allow(clippy::should_implement_trait)]
pub fn next(&mut self) -> impl Future<Output = Vec<u8, ATT_MTU>> + '_ {
self.listener.receive()
}
}

pub struct GattClient<
'reference,
'resources,
T: Controller,
const MAX_SERVICES: usize,
const MAX_NOTIF: usize,
const NOTIF_QSIZE: usize,
const L2CAP_MTU: usize = 27,
const ATT_MTU: usize = 23,
> {
pub(crate) services: Vec<ServiceHandle, MAX_SERVICES>,
pub(crate) rx: DynamicReceiver<'reference, (ConnHandle, Pdu)>,
pub(crate) ble: &'reference BleHost<'resources, T>,
pub(crate) connection: Connection<'reference>,
pub(crate) notifications_map: RefCell<[Option<u16>; MAX_NOTIF]>,
pub(crate) notifications_channels: [Channel<NoopRawMutex, Vec<u8, ATT_MTU>, NOTIF_QSIZE>; MAX_NOTIF],
}

#[cfg_attr(feature = "defmt", derive(defmt::Format))]
Expand All @@ -153,10 +181,18 @@ pub struct ServiceHandle {
uuid: Uuid,
}

impl<'reference, 'resources, T: Controller, const MAX: usize, const L2CAP_MTU: usize>
GattClient<'reference, 'resources, T, MAX, L2CAP_MTU>
impl<
'reference,
'resources,
T: Controller,
const MAX_SERVICES: usize,
const MAX_NOTIF: usize,
const NOTIF_QSIZE: usize,
const L2CAP_MTU: usize,
const ATT_MTU: usize,
> GattClient<'reference, 'resources, T, MAX_SERVICES, MAX_NOTIF, NOTIF_QSIZE, L2CAP_MTU, ATT_MTU>
{
async fn request(&mut self, req: AttReq<'_>) -> Result<Pdu, BleHostError<T::Error>> {
async fn request(&self, req: AttReq<'_>) -> Result<Pdu, BleHostError<T::Error>> {
let header = L2capHeader {
channel: crate::types::l2cap::L2CAP_CID_ATT,
length: req.size() as u16,
Expand Down Expand Up @@ -243,22 +279,31 @@ impl<'reference, 'resources, T: Controller, const MAX: usize, const L2CAP_MTU: u
if item.len() < 5 {
return Err(Error::InvalidValue.into());
}
let mut r = ReadCursor::new(item);
let _props: u8 = r.read()?;
let value_handle: u16 = r.read()?;
let value_uuid: Uuid = Uuid::from_slice(r.remaining());

if uuid == &value_uuid {
return Ok(Characteristic {
handle: value_handle,
cccd_handle: None,
});
if let AttributeData::Declaration {
props,
handle,
uuid: decl_uuid,
} = AttributeData::decode_declaration(item)?
{
if *uuid == decl_uuid {
// "notify" and "indicate" characteristic properties
let cccd_handle =
if props.any(&[CharacteristicProp::Indicate, CharacteristicProp::Notify]) {
Some(self.get_characteristic_cccd(handle).await?.0)
} else {
None
};

return Ok(Characteristic { handle, cccd_handle });
}

if handle == 0xFFFF {
return Err(Error::NotFound.into());
}
start = handle + 1;
} else {
return Err(Error::InvalidValue.into());
}
}
}
AttRsp::Error { request, handle, code } => return Err(Error::Att(code).into()),
Expand All @@ -269,6 +314,30 @@ impl<'reference, 'resources, T: Controller, const MAX: usize, const L2CAP_MTU: u
}
}

async fn get_characteristic_cccd(&mut self, char_handle: u16) -> Result<(u16, CCCD), BleHostError<T::Error>> {
let data = att::AttReq::ReadByType {
start: char_handle,
end: char_handle + 1,
attribute_type: CHARACTERISTIC_CCCD_UUID16,
};

let pdu = self.request(data).await?;

match AttRsp::decode(pdu.as_ref())? {
AttRsp::ReadByType { mut it } => {
if let Some(Ok((handle, item))) = it.next() {
Ok((handle, CCCD(u16::from_le_bytes(item.try_into().map_err(|_| Error::OutOfMemory)?))))
} else {
Err(Error::NotFound.into())
}
}
AttRsp::Error { request, handle, code } => Err(Error::Att(code).into()),
_ => {
Err(Error::InvalidValue.into())
}
}
}

/// Read a characteristic described by a handle.
///
/// The number of bytes copied into the provided buffer is returned.
Expand Down Expand Up @@ -344,4 +413,91 @@ impl<'reference, 'resources, T: Controller, const MAX: usize, const L2CAP_MTU: u
_ => Err(Error::InvalidValue.into()),
}
}

/// Subscribe to indication/notification of a given Characteristic
///
/// A listener is returned, which has a `next()` method
pub async fn subscribe(
&self,
characteristic: &Characteristic,
indication: bool,
) -> Result<NotificationListener<MAX_NOTIF, NOTIF_QSIZE, ATT_MTU>, BleHostError<T::Error>> {
let properties = u16::to_le_bytes(if indication { 0x02 } else { 0x01 });

let data = att::AttReq::Write {
handle: characteristic.cccd_handle.ok_or(Error::NotSupported)?,
data: &properties,
};

// set the CCCD
let pdu = self.request(data).await?;

match AttRsp::decode(pdu.as_ref())? {
AttRsp::Write => {
// look for a free slot in the n_channel -> handle array
for (n, item) in self.notifications_map.borrow_mut().iter_mut().enumerate() {
if item.is_none() {
item.replace(characteristic.handle);
return Ok(NotificationListener {
listener: self.notifications_channels[n].receiver().into(),
});
}
}
// otherwise, there's no space left in the array
Err(Error::InsufficientSpace.into())
}
AttRsp::Error { request, handle, code } => Err(Error::Att(code).into()),
_ => Err(Error::InvalidValue.into()),
}

}

/// Unsubscribe from a given Characteristic
pub async fn unsubscribe(&self, characteristic: &Characteristic) -> Result<(), BleHostError<T::Error>> {
let mut notifications = self.notifications_map.borrow_mut();
let (item, n) = notifications.iter_mut().enumerate().find_map(|(n, item)| {
if let Some(h) = item {
if *h == characteristic.handle { Some((item, n)) } else { None }
} else {
None
}
}).ok_or(Error::NotFound)?;

// Free up the slot in the n_channel -> handle map
item.take();
// Clear any data queued up in the channel
self.notifications_channels[n].clear();
Ok(())
}

/// Task which handles GATT client actions (needed for notifications to work)
pub async fn task(&self) -> Result<(), BleHostError<T::Error>> {
loop {
let (handle, pdu) = self.rx.receive().await;
let data = pdu.as_ref();

// handle notifications
if data[0] == ATT_HANDLE_VALUE_NTF {
let mut r = ReadCursor::new(&data[1..]);
let value_handle: u16 = r.read()?;
let value_attr = r.remaining();

// let's find the corresponding `n` first, to avoid retaining the borrow_mut() across an await point
let found_n = self.notifications_map.borrow_mut().iter().enumerate().find_map(|(n, item)| {
if let Some(handle) = item {
if *handle == value_handle {
return Some(n)
}
}
None
});

if let Some(n) = found_n {
self.notifications_channels[n]
.send(Vec::from_slice(value_attr).unwrap())
.await;
}
}
}
}
}
6 changes: 4 additions & 2 deletions host/src/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -759,10 +759,10 @@ where

/// Creates a GATT client capable of processing the GATT protocol using the provided table of attributes.
#[cfg(feature = "gatt")]
pub async fn gatt_client<'reference, const MAX: usize, const L2CAP_MTU: usize>(
pub async fn gatt_client<'reference, const MAX_SERVICES: usize, const MAX_NOTIF: usize, const NOTIF_QSIZE: usize, const L2CAP_MTU: usize>(
&'reference self,
connection: &Connection<'reference>,
) -> Result<GattClient<'reference, 'd, T, MAX, L2CAP_MTU>, BleHostError<T::Error>> {
) -> Result<GattClient<'reference, 'd, T, MAX_SERVICES, MAX_NOTIF, NOTIF_QSIZE, L2CAP_MTU>, BleHostError<T::Error>> {
let l2cap = L2capHeader { channel: 4, length: 3 };
let mut buf = [0; 7];
let mut w = WriteCursor::new(&mut buf);
Expand All @@ -780,6 +780,8 @@ where
rx: self.att_inbound.receiver().into(),
ble: self,
connection: connection.clone(),
notifications_map: RefCell::new([const { None }; MAX_NOTIF]),
notifications_channels: [const { Channel::new() }; MAX_NOTIF],
})
}

Expand Down
2 changes: 1 addition & 1 deletion host/tests/gatt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ async fn gatt_client_server() {
tokio::time::sleep(Duration::from_secs(5)).await;

println!("[central] creating gatt client");
let mut client = adapter.gatt_client::<10, 128>(&conn).await.unwrap();
let mut client = adapter.gatt_client::<10, 128, 10, 24>(&conn).await.unwrap();

println!("[central] discovering services");
let services = client.services_by_uuid(&SERVICE_UUID).await.unwrap();
Expand Down

0 comments on commit 1e512ab

Please sign in to comment.