Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for using server and client concurrently #193

Merged
merged 1 commit into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 46 additions & 12 deletions host/src/att.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ pub enum AttReq<'d> {
},
}

#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[derive(Debug)]
pub enum AttRsp<'d> {
ExchangeMtu {
mtu: u16,
Expand All @@ -173,6 +175,11 @@ pub enum AttRsp<'d> {
Write,
}

pub enum Att<'d> {
Req(AttReq<'d>),
Rsp(AttRsp<'d>),
}

impl codec::Type for AttRsp<'_> {
fn size(&self) -> usize {
AttRsp::size(self)
Expand All @@ -191,7 +198,8 @@ impl<'d> codec::Decode<'d> for AttRsp<'d> {
}
}

#[derive(Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[derive(Clone, Debug)]
pub struct FindByTypeValueIter<'d> {
cursor: ReadCursor<'d>,
}
Expand All @@ -211,7 +219,8 @@ impl FindByTypeValueIter<'_> {
}
}

#[derive(Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[derive(Clone, Debug)]
pub struct ReadByTypeIter<'d> {
item_len: usize,
cursor: ReadCursor<'d>,
Expand Down Expand Up @@ -285,10 +294,14 @@ impl<'d> AttRsp<'d> {
Ok(())
}

pub fn decode(packet: &'d [u8]) -> Result<AttRsp<'d>, codec::Error> {
let mut r = ReadCursor::new(packet);
let kind: u8 = r.read()?;
match kind {
pub fn decode(data: &'d [u8]) -> Result<AttRsp<'d>, codec::Error> {
let mut r = ReadCursor::new(data);
let opcode: u8 = r.read()?;
AttRsp::decode_with_opcode(opcode, r)
}

pub fn decode_with_opcode(opcode: u8, mut r: ReadCursor<'d>) -> Result<AttRsp<'d>, codec::Error> {
match opcode {
ATT_FIND_BY_TYPE_VALUE_RSP => Ok(Self::FindByTypeValue {
it: FindByTypeValueIter { cursor: r },
}),
Expand Down Expand Up @@ -338,8 +351,22 @@ impl codec::Encode for AttReq<'_> {
}

impl<'d> codec::Decode<'d> for AttReq<'d> {
fn decode(src: &'d [u8]) -> Result<Self, codec::Error> {
AttReq::decode(src)
fn decode(data: &'d [u8]) -> Result<AttReq<'d>, codec::Error> {
AttReq::decode(data)
}
}

impl<'d> Att<'d> {
pub fn decode(data: &'d [u8]) -> Result<Att<'d>, codec::Error> {
let mut r = ReadCursor::new(data);
let opcode: u8 = r.read()?;
if opcode % 2 == 0 {
let req = AttReq::decode_with_opcode(opcode, r)?;
Ok(Att::Req(req))
} else {
let rsp = AttRsp::decode_with_opcode(opcode, r)?;
Ok(Att::Rsp(rsp))
}
}
}

Expand Down Expand Up @@ -405,11 +432,15 @@ impl<'d> AttReq<'d> {
}
Ok(())
}
pub fn decode(packet: &'d [u8]) -> Result<AttReq<'d>, codec::Error> {
let mut r = ReadCursor::new(packet);

pub fn decode(data: &'d [u8]) -> Result<AttReq<'d>, codec::Error> {
let mut r = ReadCursor::new(data);
let opcode: u8 = r.read()?;
let payload = r.remaining();
AttReq::decode_with_opcode(opcode, r)
}

pub fn decode_with_opcode(opcode: u8, r: ReadCursor<'d>) -> Result<AttReq<'d>, codec::Error> {
let payload = r.remaining();
match opcode {
ATT_READ_BY_GROUP_TYPE_REQ => {
let start_handle = (payload[0] as u16) + ((payload[1] as u16) << 8);
Expand Down Expand Up @@ -511,7 +542,10 @@ impl<'d> AttReq<'d> {
let offset = (payload[2] as u16) + ((payload[3] as u16) << 8);
Ok(Self::ReadBlob { handle, offset })
}
_ => Err(codec::Error::InvalidValue),
code => {
warn!("[att] unknown opcode {:x}", code);
Err(codec::Error::InvalidValue)
}
}
}
}
2 changes: 2 additions & 0 deletions host/src/cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ impl<'d> WriteCursor<'d> {
}

#[derive(Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[derive(Debug)]
pub struct ReadCursor<'d> {
pos: usize,
data: &'d [u8],
Expand Down
10 changes: 6 additions & 4 deletions host/src/gatt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl<'reference, 'values, C: Controller, M: RawMutex, const MAX: usize, const L2
Self {
stack,
server: AttributeServer::new(table),
rx: stack.host.att_inbound.receiver().into(),
rx: stack.host.att_server.receiver().into(),
tx: stack.host.outbound.sender().into(),
connections: &stack.host.connections,
}
Expand Down Expand Up @@ -277,7 +277,7 @@ impl<'reference, C: Controller, const MAX_SERVICES: usize, const L2CAP_MTU: usiz

Ok(Self {
known_services: RefCell::new(heapless::Vec::new()),
rx: stack.host.att_inbound.receiver().into(),
rx: stack.host.att_client.receiver().into(),
stack,
connection: connection.clone(),

Expand All @@ -304,7 +304,8 @@ impl<'reference, C: Controller, const MAX_SERVICES: usize, const L2CAP_MTU: usiz
};

let pdu = self.request(data).await?;
match AttRsp::decode(pdu.as_ref())? {
let res = AttRsp::decode(pdu.as_ref())?;
match res {
AttRsp::Error { request, handle, code } => {
if code == att::AttErrorCode::AttributeNotFound {
break;
Expand Down Expand Up @@ -332,7 +333,8 @@ impl<'reference, C: Controller, const MAX_SERVICES: usize, const L2CAP_MTU: usiz
}
start = end + 1;
}
_ => {
res => {
trace!("[gatt client] response: {:?}", res);
return Err(Error::InvalidValue.into());
}
}
Expand Down
47 changes: 31 additions & 16 deletions host/src/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ pub(crate) struct BleHost<'d, T> {
pub(crate) reassembly: PacketReassembly<'d>,
pub(crate) channels: ChannelManager<'d, { config::L2CAP_RX_QUEUE_SIZE }>,
#[cfg(feature = "gatt")]
pub(crate) att_inbound: Channel<NoopRawMutex, (ConnHandle, Pdu<'d>), { config::L2CAP_RX_QUEUE_SIZE }>,
pub(crate) att_server: Channel<NoopRawMutex, (ConnHandle, Pdu<'d>), { config::L2CAP_RX_QUEUE_SIZE }>,
#[cfg(feature = "gatt")]
pub(crate) att_client: Channel<NoopRawMutex, (ConnHandle, Pdu<'d>), { config::L2CAP_RX_QUEUE_SIZE }>,
pub(crate) rx_pool: &'d dyn GlobalPacketPool<'d>,
pub(crate) outbound: Channel<NoopRawMutex, (ConnHandle, Pdu<'d>), { config::L2CAP_TX_QUEUE_SIZE }>,

Expand Down Expand Up @@ -206,7 +208,9 @@ where
channels: ChannelManager::new(rx_pool, channels, channels_rx),
rx_pool,
#[cfg(feature = "gatt")]
att_inbound: Channel::new(),
att_server: Channel::new(),
#[cfg(feature = "gatt")]
att_client: Channel::new(),
#[cfg(feature = "scan")]
scanner: Channel::new(),
advertise_state: AdvState::new(advertise_handles),
Expand Down Expand Up @@ -337,9 +341,8 @@ where
L2CAP_CID_ATT => {
// Handle ATT MTU exchange here since it doesn't strictly require
// gatt to be enabled.
if let Ok(att::AttReq::ExchangeMtu { mtu }) =
att::AttReq::decode(&packet.as_ref()[..header.length as usize])
{
let a = att::Att::decode(&packet.as_ref()[..header.length as usize]);
if let Ok(att::Att::Req(att::AttReq::ExchangeMtu { mtu })) = a {
let mtu = self.connections.exchange_att_mtu(acl.handle(), mtu);

let rsp = att::AttRsp::ExchangeMtu { mtu };
Expand All @@ -352,25 +355,37 @@ where
w.write_hci(&l2cap)?;
w.write(rsp)?;

trace!("[host] agreed att MTU of {}", mtu);
info!("[host] agreed att MTU of {}", mtu);
let len = w.len();
if let Err(e) = self.outbound.try_send((acl.handle(), Pdu::new(packet, len))) {
return Err(Error::OutOfMemory);
}
} else if let Ok(att::AttRsp::ExchangeMtu { mtu }) =
att::AttRsp::decode(&packet.as_ref()[..header.length as usize])
{
trace!("[host] remote agreed att MTU of {}", mtu);
} else if let Ok(att::Att::Rsp(att::AttRsp::ExchangeMtu { mtu })) = a {
info!("[host] remote agreed att MTU of {}", mtu);
self.connections.exchange_att_mtu(acl.handle(), mtu);
} else {
#[cfg(feature = "gatt")]
if let Err(e) = self
.att_inbound
.try_send((acl.handle(), Pdu::new(packet, header.length as usize)))
{
return Err(Error::OutOfMemory);
match a {
Ok(att::Att::Req(_)) => {
if let Err(e) = self
.att_server
.try_send((acl.handle(), Pdu::new(packet, header.length as usize)))
{
return Err(Error::OutOfMemory);
}
}
Ok(att::Att::Rsp(_)) => {
if let Err(e) = self
.att_client
.try_send((acl.handle(), Pdu::new(packet, header.length as usize)))
{
return Err(Error::OutOfMemory);
}
}
Err(e) => {
warn!("Error decoding attribute payload: {:?}", e);
}
}

#[cfg(not(feature = "gatt"))]
return Err(Error::NotSupported);
}
Expand Down
Loading