diff --git a/host/src/att.rs b/host/src/att.rs index 24f7dbf5..7e38ec0c 100644 --- a/host/src/att.rs +++ b/host/src/att.rs @@ -152,6 +152,8 @@ pub enum AttReq<'d> { }, } +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[derive(Debug)] pub enum AttRsp<'d> { ExchangeMtu { mtu: u16, @@ -173,6 +175,11 @@ pub enum AttRsp<'d> { Write, } +pub enum Att<'d> { + Req(AttReq<'d>), + Rsp(AttRsp<'d>), +} + impl codec::Type for AttRsp<'_> { fn size(&self) -> usize { AttRsp::size(self) @@ -191,7 +198,8 @@ impl<'d> codec::Decode<'d> for AttRsp<'d> { } } -#[derive(Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[derive(Clone, Debug)] pub struct FindByTypeValueIter<'d> { cursor: ReadCursor<'d>, } @@ -211,7 +219,8 @@ impl FindByTypeValueIter<'_> { } } -#[derive(Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[derive(Clone, Debug)] pub struct ReadByTypeIter<'d> { item_len: usize, cursor: ReadCursor<'d>, @@ -285,10 +294,14 @@ impl<'d> AttRsp<'d> { Ok(()) } - pub fn decode(packet: &'d [u8]) -> Result, codec::Error> { - let mut r = ReadCursor::new(packet); - let kind: u8 = r.read()?; - match kind { + pub fn decode(data: &'d [u8]) -> Result, codec::Error> { + let mut r = ReadCursor::new(data); + let opcode: u8 = r.read()?; + AttRsp::decode_with_opcode(opcode, r) + } + + pub fn decode_with_opcode(opcode: u8, mut r: ReadCursor<'d>) -> Result, codec::Error> { + match opcode { ATT_FIND_BY_TYPE_VALUE_RSP => Ok(Self::FindByTypeValue { it: FindByTypeValueIter { cursor: r }, }), @@ -338,8 +351,22 @@ impl codec::Encode for AttReq<'_> { } impl<'d> codec::Decode<'d> for AttReq<'d> { - fn decode(src: &'d [u8]) -> Result { - AttReq::decode(src) + fn decode(data: &'d [u8]) -> Result, codec::Error> { + AttReq::decode(data) + } +} + +impl<'d> Att<'d> { + pub fn decode(data: &'d [u8]) -> Result, codec::Error> { + let mut r = ReadCursor::new(data); + let opcode: u8 = r.read()?; + if opcode % 2 == 0 { + let req = AttReq::decode_with_opcode(opcode, r)?; + Ok(Att::Req(req)) + } else { + let rsp = AttRsp::decode_with_opcode(opcode, r)?; + Ok(Att::Rsp(rsp)) + } } } @@ -405,11 +432,15 @@ impl<'d> AttReq<'d> { } Ok(()) } - pub fn decode(packet: &'d [u8]) -> Result, codec::Error> { - let mut r = ReadCursor::new(packet); + + pub fn decode(data: &'d [u8]) -> Result, codec::Error> { + let mut r = ReadCursor::new(data); let opcode: u8 = r.read()?; - let payload = r.remaining(); + AttReq::decode_with_opcode(opcode, r) + } + pub fn decode_with_opcode(opcode: u8, r: ReadCursor<'d>) -> Result, codec::Error> { + let payload = r.remaining(); match opcode { ATT_READ_BY_GROUP_TYPE_REQ => { let start_handle = (payload[0] as u16) + ((payload[1] as u16) << 8); @@ -511,7 +542,10 @@ impl<'d> AttReq<'d> { let offset = (payload[2] as u16) + ((payload[3] as u16) << 8); Ok(Self::ReadBlob { handle, offset }) } - _ => Err(codec::Error::InvalidValue), + code => { + warn!("[att] unknown opcode {:x}", code); + Err(codec::Error::InvalidValue) + } } } } diff --git a/host/src/cursor.rs b/host/src/cursor.rs index 47af0660..cfa569a1 100644 --- a/host/src/cursor.rs +++ b/host/src/cursor.rs @@ -109,6 +109,8 @@ impl<'d> WriteCursor<'d> { } #[derive(Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[derive(Debug)] pub struct ReadCursor<'d> { pos: usize, data: &'d [u8], diff --git a/host/src/gatt.rs b/host/src/gatt.rs index f55b4dcf..31484080 100644 --- a/host/src/gatt.rs +++ b/host/src/gatt.rs @@ -43,7 +43,7 @@ impl<'reference, 'values, C: Controller, M: RawMutex, const MAX: usize, const L2 Self { stack, server: AttributeServer::new(table), - rx: stack.host.att_inbound.receiver().into(), + rx: stack.host.att_server.receiver().into(), tx: stack.host.outbound.sender().into(), connections: &stack.host.connections, } @@ -277,7 +277,7 @@ impl<'reference, C: Controller, const MAX_SERVICES: usize, const L2CAP_MTU: usiz Ok(Self { known_services: RefCell::new(heapless::Vec::new()), - rx: stack.host.att_inbound.receiver().into(), + rx: stack.host.att_client.receiver().into(), stack, connection: connection.clone(), @@ -304,7 +304,8 @@ impl<'reference, C: Controller, const MAX_SERVICES: usize, const L2CAP_MTU: usiz }; let pdu = self.request(data).await?; - match AttRsp::decode(pdu.as_ref())? { + let res = AttRsp::decode(pdu.as_ref())?; + match res { AttRsp::Error { request, handle, code } => { if code == att::AttErrorCode::AttributeNotFound { break; @@ -332,7 +333,8 @@ impl<'reference, C: Controller, const MAX_SERVICES: usize, const L2CAP_MTU: usiz } start = end + 1; } - _ => { + res => { + trace!("[gatt client] response: {:?}", res); return Err(Error::InvalidValue.into()); } } diff --git a/host/src/host.rs b/host/src/host.rs index b756933d..ed68dc48 100644 --- a/host/src/host.rs +++ b/host/src/host.rs @@ -62,7 +62,9 @@ pub(crate) struct BleHost<'d, T> { pub(crate) reassembly: PacketReassembly<'d>, pub(crate) channels: ChannelManager<'d, { config::L2CAP_RX_QUEUE_SIZE }>, #[cfg(feature = "gatt")] - pub(crate) att_inbound: Channel), { config::L2CAP_RX_QUEUE_SIZE }>, + pub(crate) att_server: Channel), { config::L2CAP_RX_QUEUE_SIZE }>, + #[cfg(feature = "gatt")] + pub(crate) att_client: Channel), { config::L2CAP_RX_QUEUE_SIZE }>, pub(crate) rx_pool: &'d dyn GlobalPacketPool<'d>, pub(crate) outbound: Channel), { config::L2CAP_TX_QUEUE_SIZE }>, @@ -206,7 +208,9 @@ where channels: ChannelManager::new(rx_pool, channels, channels_rx), rx_pool, #[cfg(feature = "gatt")] - att_inbound: Channel::new(), + att_server: Channel::new(), + #[cfg(feature = "gatt")] + att_client: Channel::new(), #[cfg(feature = "scan")] scanner: Channel::new(), advertise_state: AdvState::new(advertise_handles), @@ -337,9 +341,8 @@ where L2CAP_CID_ATT => { // Handle ATT MTU exchange here since it doesn't strictly require // gatt to be enabled. - if let Ok(att::AttReq::ExchangeMtu { mtu }) = - att::AttReq::decode(&packet.as_ref()[..header.length as usize]) - { + let a = att::Att::decode(&packet.as_ref()[..header.length as usize]); + if let Ok(att::Att::Req(att::AttReq::ExchangeMtu { mtu })) = a { let mtu = self.connections.exchange_att_mtu(acl.handle(), mtu); let rsp = att::AttRsp::ExchangeMtu { mtu }; @@ -352,25 +355,37 @@ where w.write_hci(&l2cap)?; w.write(rsp)?; - trace!("[host] agreed att MTU of {}", mtu); + info!("[host] agreed att MTU of {}", mtu); let len = w.len(); if let Err(e) = self.outbound.try_send((acl.handle(), Pdu::new(packet, len))) { return Err(Error::OutOfMemory); } - } else if let Ok(att::AttRsp::ExchangeMtu { mtu }) = - att::AttRsp::decode(&packet.as_ref()[..header.length as usize]) - { - trace!("[host] remote agreed att MTU of {}", mtu); + } else if let Ok(att::Att::Rsp(att::AttRsp::ExchangeMtu { mtu })) = a { + info!("[host] remote agreed att MTU of {}", mtu); self.connections.exchange_att_mtu(acl.handle(), mtu); } else { #[cfg(feature = "gatt")] - if let Err(e) = self - .att_inbound - .try_send((acl.handle(), Pdu::new(packet, header.length as usize))) - { - return Err(Error::OutOfMemory); + match a { + Ok(att::Att::Req(_)) => { + if let Err(e) = self + .att_server + .try_send((acl.handle(), Pdu::new(packet, header.length as usize))) + { + return Err(Error::OutOfMemory); + } + } + Ok(att::Att::Rsp(_)) => { + if let Err(e) = self + .att_client + .try_send((acl.handle(), Pdu::new(packet, header.length as usize))) + { + return Err(Error::OutOfMemory); + } + } + Err(e) => { + warn!("Error decoding attribute payload: {:?}", e); + } } - #[cfg(not(feature = "gatt"))] return Err(Error::NotSupported); }