Skip to content

Commit

Permalink
Merge pull request #48 from embassy-rs/refcount-conn-channel
Browse files Browse the repository at this point in the history
make connections reference counted
  • Loading branch information
lulf authored May 24, 2024
2 parents a2687b0 + b806ff9 commit e5319ec
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 126 deletions.
70 changes: 41 additions & 29 deletions host/src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
//! BLE connection.
use bt_hci::cmd::le::LeConnUpdate;
use bt_hci::cmd::link_control::Disconnect;
use bt_hci::cmd::status::ReadRssi;
use bt_hci::controller::{Controller, ControllerCmdAsync, ControllerCmdSync};
use bt_hci::controller::{ControllerCmdAsync, ControllerCmdSync};
use bt_hci::param::{BdAddr, ConnHandle, DisconnectReason, LeConnRole};
use embassy_time::Duration;

use crate::connection_manager::DynamicConnectionManager;
use crate::host::BleHost;
use crate::scan::ScanConfig;
use crate::BleHostError;

#[derive(Clone)]
pub struct Connection {
handle: ConnHandle,
}

pub struct ConnectConfig<'d> {
pub scan_config: ScanConfig<'d>,
pub connect_params: ConnectParams,
Expand All @@ -40,44 +35,60 @@ impl Default for ConnectParams {
}
}

impl Connection {
pub(crate) fn new(handle: ConnHandle) -> Self {
Self { handle }
pub struct Connection<'d> {
index: u8,
manager: &'d dyn DynamicConnectionManager,
}

impl<'d> Clone for Connection<'d> {
fn clone(&self) -> Self {
self.manager.inc_ref(self.index);
Self {
index: self.index,
manager: self.manager,
}
}
}

/// Connection handle of this connection.
pub fn handle(&self) -> ConnHandle {
self.handle
impl<'d> Drop for Connection<'d> {
fn drop(&mut self) {
self.manager.dec_ref(self.index);
}
}

/// Request disconnection of this connection handle.
pub fn disconnect<T: Controller + ControllerCmdSync<Disconnect>>(
&mut self,
ble: &BleHost<'_, T>,
) -> Result<(), BleHostError<T::Error>> {
ble.connections
.request_disconnect(self.handle, DisconnectReason::RemoteUserTerminatedConn)?;
Ok(())
impl<'d> Connection<'d> {
pub(crate) fn new(index: u8, manager: &'d dyn DynamicConnectionManager) -> Self {
manager.inc_ref(index);
Self { index, manager }
}

/// Connection handle of this connection.
pub fn handle(&self) -> ConnHandle {
self.manager.handle(self.index)
}

/// The connection role for this connection.
pub fn role<T: Controller>(&self, ble: &BleHost<'_, T>) -> Result<LeConnRole, BleHostError<T::Error>> {
let role = ble.connections.role(self.handle)?;
Ok(role)
pub fn role(&self) -> LeConnRole {
self.manager.role(self.index)
}

/// The peer address for this connection.
pub fn peer_address<T: Controller>(&self, ble: &BleHost<'_, T>) -> Result<BdAddr, BleHostError<T::Error>> {
let addr = ble.connections.peer_address(self.handle)?;
Ok(addr)
pub fn peer_address(&self) -> BdAddr {
self.manager.peer_address(self.index)
}

pub fn disconnect(&self) {
self.manager
.disconnect(self.index, DisconnectReason::RemoteUserTerminatedConn);
}

/// The RSSI value for this connection.
pub async fn rssi<T>(&self, ble: &BleHost<'_, T>) -> Result<i8, BleHostError<T::Error>>
where
T: ControllerCmdSync<ReadRssi>,
{
let ret = ble.command(ReadRssi::new(self.handle)).await?;
let handle = self.handle();
let ret = ble.command(ReadRssi::new(handle)).await?;
Ok(ret.rssi)
}

Expand All @@ -90,9 +101,10 @@ impl Connection {
where
T: ControllerCmdAsync<LeConnUpdate>,
{
let handle = self.handle();
match ble
.async_command(LeConnUpdate::new(
self.handle,
handle,
params.min_connection_interval.into(),
params.max_connection_interval.into(),
params.max_latency,
Expand Down
142 changes: 99 additions & 43 deletions host/src/connection_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use core::cell::RefCell;
use core::future::poll_fn;
use core::task::{Context, Poll};

use bt_hci::event::le::LeConnectionComplete;
use bt_hci::param::{AddrKind, BdAddr, ConnHandle, DisconnectReason, LeConnRole};
use embassy_sync::blocking_mutex::raw::NoopRawMutex;
use embassy_sync::signal::Signal;
Expand Down Expand Up @@ -45,42 +44,37 @@ impl<'d> ConnectionManager<'d> {
}
}

pub(crate) fn role(&self, h: ConnHandle) -> Result<LeConnRole, Error> {
let state = self.state.borrow();
for storage in state.connections.iter() {
if storage.state == ConnectionState::Connected && storage.handle.unwrap() == h {
return Ok(storage.role.unwrap());
}
}
Err(Error::NotFound)
pub(crate) fn role(&self, index: u8) -> LeConnRole {
self.with_mut(|state| {
let state = &mut state.connections[index as usize];
assert_eq!(state.state, ConnectionState::Connected);
state.role.unwrap()
})
}

pub(crate) fn peer_address(&self, h: ConnHandle) -> Result<BdAddr, Error> {
let state = self.state.borrow();
for storage in state.connections.iter() {
if storage.state == ConnectionState::Connected && storage.handle.unwrap() == h {
return Ok(storage.peer_addr.unwrap());
}
}
Err(Error::NotFound)
pub(crate) fn handle(&self, index: u8) -> ConnHandle {
self.with_mut(|state| {
let state = &mut state.connections[index as usize];
assert_eq!(state.state, ConnectionState::Connected);
state.handle.unwrap()
})
}

pub(crate) fn request_disconnect(&self, h: ConnHandle, reason: DisconnectReason) -> Result<(), Error> {
let mut state = self.state.borrow_mut();
for storage in state.connections.iter_mut() {
match storage.state {
ConnectionState::Connecting if storage.handle.unwrap() == h => {
storage.state = ConnectionState::Disconnecting(reason);
return Ok(());
}
ConnectionState::Connected if storage.handle.unwrap() == h => {
storage.state = ConnectionState::Disconnecting(reason);
return Ok(());
}
_ => {}
pub(crate) fn peer_address(&self, index: u8) -> BdAddr {
self.with_mut(|state| {
let state = &mut state.connections[index as usize];
assert_eq!(state.state, ConnectionState::Connected);
state.peer_addr.unwrap()
})
}

pub(crate) fn request_disconnect(&self, index: u8, reason: DisconnectReason) {
self.with_mut(|state| {
let state = &mut state.connections[index as usize];
if state.state == ConnectionState::Connected {
state.state = ConnectionState::Disconnecting(reason);
}
}
Err(Error::NotFound)
})
}

pub(crate) fn poll_disconnecting<'m>(&'m self, cx: &mut Context<'_>) -> Poll<DisconnectIter<'m, 'd>> {
Expand Down Expand Up @@ -117,17 +111,23 @@ impl<'d> ConnectionManager<'d> {
Err(Error::NotFound)
}

pub(crate) fn connect(&self, handle: ConnHandle, info: &LeConnectionComplete) -> Result<(), Error> {
pub(crate) fn connect(
&self,
handle: ConnHandle,
peer_addr_kind: AddrKind,
peer_addr: BdAddr,
role: LeConnRole,
) -> Result<(), Error> {
let mut state = self.state.borrow_mut();
let default_credits = state.default_link_credits;
for storage in state.connections.iter_mut() {
if let ConnectionState::Disconnected = storage.state {
if ConnectionState::Disconnected == storage.state && storage.refcount == 0 {
storage.state = ConnectionState::Connecting;
storage.link_credits = default_credits;
storage.handle.replace(handle);
storage.peer_addr_kind.replace(info.peer_addr_kind);
storage.peer_addr.replace(info.peer_addr);
storage.role.replace(info.role);
storage.peer_addr_kind.replace(peer_addr_kind);
storage.peer_addr.replace(peer_addr);
storage.role.replace(role);
state.accept_waker.wake();
return Ok(());
}
Expand All @@ -144,34 +144,64 @@ impl<'d> ConnectionManager<'d> {
self.canceled.signal(());
}

pub(crate) fn poll_accept(&self, peers: &[(AddrKind, &BdAddr)], cx: &mut Context<'_>) -> Poll<ConnHandle> {
pub(crate) fn poll_accept(&self, peers: &[(AddrKind, &BdAddr)], cx: &mut Context<'_>) -> Poll<u8> {
let mut state = self.state.borrow_mut();
state.accept_waker.register(cx.waker());
for storage in state.connections.iter_mut() {
for (idx, storage) in state.connections.iter_mut().enumerate() {
if let ConnectionState::Connecting = storage.state {
let handle = storage.handle.unwrap();
if !peers.is_empty() {
for peer in peers.iter() {
if storage.peer_addr_kind.unwrap() == peer.0 && &storage.peer_addr.unwrap() == peer.1 {
storage.state = ConnectionState::Connected;
return Poll::Ready(handle);
return Poll::Ready(idx as u8);
}
}
} else {
storage.state = ConnectionState::Connected;
return Poll::Ready(handle);
return Poll::Ready(idx as u8);
}
}
}
Poll::Pending
}

fn with_mut<F: FnOnce(&mut State<'d>) -> R, R>(&self, f: F) -> R {
let mut state = self.state.borrow_mut();
f(&mut state)
}

pub(crate) fn log_status(&self) {
let state = self.state.borrow();
state.print();
}

pub(crate) async fn accept(&self, peers: &[(AddrKind, &BdAddr)]) -> ConnHandle {
pub(crate) fn inc_ref(&self, index: u8) {
self.with_mut(|state| {
let state = &mut state.connections[index as usize];
state.refcount = unwrap!(
state.refcount.checked_add(1),
"Too many references to the same connection"
);
});
}

pub(crate) fn dec_ref(&self, index: u8) {
self.with_mut(|state| {
let state = &mut state.connections[index as usize];
state.refcount = unwrap!(
state.refcount.checked_sub(1),
"bug: dropping a connection with refcount 0"
);
if state.refcount == 0 {
if state.state == ConnectionState::Connected {
state.state = ConnectionState::Disconnecting(DisconnectReason::RemoteUserTerminatedConn);
}
}
});
}

pub(crate) async fn accept(&self, peers: &[(AddrKind, &BdAddr)]) -> u8 {
poll_fn(move |cx| self.poll_accept(peers, cx)).await
}

Expand Down Expand Up @@ -234,12 +264,36 @@ impl<'d> ConnectionManager<'d> {
}
}

pub trait DynamicConnectionManager {
pub(crate) trait DynamicConnectionManager {
fn role(&self, index: u8) -> LeConnRole;
fn handle(&self, index: u8) -> ConnHandle;
fn peer_address(&self, index: u8) -> BdAddr;
fn inc_ref(&self, index: u8);
fn dec_ref(&self, index: u8);
fn disconnect(&self, index: u8, reason: DisconnectReason);
fn get_att_mtu(&self, conn: ConnHandle) -> u16;
fn exchange_att_mtu(&self, conn: ConnHandle, mtu: u16) -> u16;
}

impl<'d> DynamicConnectionManager for ConnectionManager<'d> {
fn role(&self, index: u8) -> LeConnRole {
ConnectionManager::role(self, index)
}
fn handle(&self, index: u8) -> ConnHandle {
ConnectionManager::handle(self, index)
}
fn peer_address(&self, index: u8) -> BdAddr {
ConnectionManager::peer_address(self, index)
}
fn inc_ref(&self, index: u8) {
ConnectionManager::inc_ref(self, index)
}
fn dec_ref(&self, index: u8) {
ConnectionManager::dec_ref(self, index)
}
fn disconnect(&self, index: u8, reason: DisconnectReason) {
ConnectionManager::request_disconnect(self, index, reason)
}
fn get_att_mtu(&self, conn: ConnHandle) -> u16 {
let mut state = self.state.borrow_mut();
for storage in state.connections.iter_mut() {
Expand Down Expand Up @@ -297,6 +351,7 @@ pub struct ConnectionStorage {
pub att_mtu: u16,
pub link_credits: usize,
pub link_credit_waker: WakerRegistration,
pub refcount: u8,
}

impl ConnectionStorage {
Expand All @@ -309,6 +364,7 @@ impl ConnectionStorage {
att_mtu: 23,
link_credits: 0,
link_credit_waker: WakerRegistration::new(),
refcount: 0,
};
}

Expand Down
4 changes: 2 additions & 2 deletions host/src/gatt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ impl<'reference, 'values, 'resources, M: RawMutex, T: Controller, const MAX: usi
pub async fn notify(
&self,
handle: CharacteristicHandle,
connection: &Connection,
connection: &Connection<'_>,
value: &[u8],
) -> Result<(), BleHostError<T::Error>> {
let conn = connection.handle();
Expand Down Expand Up @@ -106,7 +106,7 @@ impl<'reference, 'values, 'resources, M: RawMutex, T: Controller, const MAX: usi
#[derive(Clone)]
pub enum GattEvent<'reference, 'values> {
Write {
connection: &'reference Connection,
connection: &'reference Connection<'reference>,
handle: CharacteristicHandle,
value: &'values [u8],
},
Expand Down
Loading

0 comments on commit e5319ec

Please sign in to comment.