diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index 1c3d25c2..cc67a468 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -39,12 +39,36 @@ use rs_matter::persist::Psm; use rs_matter::respond::DefaultResponder; use rs_matter::secure_channel::spake2p::VerifierData; use rs_matter::transport::core::MATTER_SOCKET_BIND_ADDR; -use rs_matter::utils::buf::PooledBuffers; +use rs_matter::utils::init::InitMaybeUninit; use rs_matter::utils::select::Coalesce; +use rs_matter::utils::storage::pooled::PooledBuffers; use rs_matter::MATTER_PORT; +use static_cell::StaticCell; mod dev_att; +static DEV_DET: BasicInfoConfig = BasicInfoConfig { + vid: 0xFFF1, + pid: 0x8000, + hw_ver: 2, + sw_ver: 1, + sw_ver_str: "1", + serial_no: "aabbccdd", + device_name: "OnOff Light", + product_name: "Light123", + vendor_name: "Vendor PQR", +}; + +static DEV_ATT: dev_att::HardCodedDevAtt = dev_att::HardCodedDevAtt::new(); + +static MATTER: StaticCell = StaticCell::new(); + +static BUFFERS: StaticCell> = StaticCell::new(); + +static SUBSCRIPTIONS: StaticCell> = StaticCell::new(); + +static PSM: StaticCell> = StaticCell::new(); + fn main() -> Result<(), Error> { let thread = std::thread::Builder::new() // Increase the stack size until the example can work without stack blowups. @@ -54,7 +78,7 @@ fn main() -> Result<(), Error> { // e.g., an opt-level of "0" will require a several times' larger stack. // // Optimizing/lowering `rs-matter` memory consumption is an ongoing topic. - .stack_size(95 * 1024) + .stack_size(65 * 1024) .spawn(run) .unwrap(); @@ -72,36 +96,22 @@ fn run() -> Result<(), Error> { core::mem::size_of::>() ); - let dev_det = BasicInfoConfig { - vid: 0xFFF1, - pid: 0x8000, - hw_ver: 2, - sw_ver: 1, - sw_ver_str: "1", - serial_no: "aabbccdd", - device_name: "OnOff Light", - product_name: "Light123", - vendor_name: "Vendor PQR", - }; - - let dev_att = dev_att::HardCodedDevAtt::new(); - - let matter = Matter::new( - &dev_det, - &dev_att, + let matter = MATTER.uninit().init_with(Matter::init( + &DEV_DET, + &DEV_ATT, // NOTE: // For `no_std` environments, provide your own epoch and rand functions here MdnsService::Builtin, rs_matter::utils::epoch::sys_epoch, rs_matter::utils::rand::sys_rand, MATTER_PORT, - ); + )); matter.initialize_transport_buffers()?; info!("Matter initialized"); - let buffers = PooledBuffers::<10, NoopRawMutex, _>::new(0); + let buffers = BUFFERS.uninit().init_with(PooledBuffers::init(0)); info!("IM buffers initialized"); @@ -109,14 +119,14 @@ fn run() -> Result<(), Error> { let on_off = cluster_on_off::OnOffCluster::new(Dataver::new_rand(matter.rand())); - let subscriptions = Subscriptions::<3>::new(); + let subscriptions = SUBSCRIPTIONS.uninit().init_with(Subscriptions::init()); // Assemble our Data Model handler by composing the predefined Root Endpoint handler with our custom On/Off clusters let dm_handler = HandlerCompat(dm_handler(&matter, &on_off)); // Create a default responder capable of handling up to 3 subscriptions // All other subscription requests will be turned down with "resource exhausted" - let responder = DefaultResponder::new(&matter, &buffers, &subscriptions, dm_handler); + let responder = DefaultResponder::new(&matter, buffers, &subscriptions, dm_handler); info!( "Responder memory: Responder={}B, Runner={}B", core::mem::size_of_val(&responder), @@ -161,8 +171,10 @@ fn run() -> Result<(), Error> { // NOTE: // Replace with your own persister for e.g. `no_std` environments - let mut psm = Psm::new(&matter, std::env::temp_dir().join("rs-matter"))?; - let mut persist = pin!(psm.run()); + + let psm = PSM.uninit().init_with(Psm::init()); + + let mut persist = pin!(psm.run(std::env::temp_dir().join("rs-matter"), &matter)); // Combine all async tasks in a single one let all = select4( diff --git a/examples/onoff_light_bt/src/comm.rs b/examples/onoff_light_bt/src/comm.rs index 6052b485..3c1f2135 100644 --- a/examples/onoff_light_bt/src/comm.rs +++ b/examples/onoff_light_bt/src/comm.rs @@ -32,7 +32,7 @@ use rs_matter::interaction_model::core::IMStatusCode; use rs_matter::interaction_model::messages::ib::Status; use rs_matter::tlv::{FromTLV, OctetStr, TLVElement}; use rs_matter::transport::exchange::Exchange; -use rs_matter::utils::notification::Notification; +use rs_matter::utils::sync::Notification; /// A _fake_ cluster implementing the Matter Network Commissioning Cluster /// for managing WiFi networks. diff --git a/examples/onoff_light_bt/src/main.rs b/examples/onoff_light_bt/src/main.rs index 4d6725ab..8099eb28 100644 --- a/examples/onoff_light_bt/src/main.rs +++ b/examples/onoff_light_bt/src/main.rs @@ -62,10 +62,9 @@ use rs_matter::respond::DefaultResponder; use rs_matter::secure_channel::spake2p::VerifierData; use rs_matter::transport::core::MATTER_SOCKET_BIND_ADDR; use rs_matter::transport::network::btp::{Btp, BtpContext}; -use rs_matter::utils::buf::PooledBuffers; -use rs_matter::utils::notification::Notification; use rs_matter::utils::select::Coalesce; -use rs_matter::utils::std_mutex::StdRawMutex; +use rs_matter::utils::storage::pooled::PooledBuffers; +use rs_matter::utils::sync::{blocking::raw::StdRawMutex, Notification}; use rs_matter::MATTER_PORT; mod comm; @@ -182,8 +181,8 @@ fn run() -> Result<(), Error> { // NOTE: // Replace with your own persister for e.g. `no_std` environments - let mut psm = Psm::new(&matter, std::env::temp_dir().join("rs-matter"))?; - let mut persist = pin!(psm.run()); + let mut psm: Psm<4096> = Psm::new(); + let mut persist = pin!(psm.run(std::env::temp_dir().join("rs-matter"), &matter)); if !matter.is_commissioned() { // Not commissioned yet, start commissioning first diff --git a/rs-matter/Cargo.toml b/rs-matter/Cargo.toml index 5f46190c..61548f9d 100644 --- a/rs-matter/Cargo.toml +++ b/rs-matter/Cargo.toml @@ -44,7 +44,8 @@ critical-section = "1.1" domain = { version = "0.10", default-features = false, features = ["heapless"] } portable-atomic = "1" qrcodegen-no-heap = "1.8" -scopeguard = "1" +scopeguard = { version = "1", default-features = false } +pinned-init = { version = "0.0.8", default-features = false } # crypto openssl = { version = "0.10", optional = true } @@ -81,8 +82,9 @@ tokio-stream = { version = "0.1" } [dev-dependencies] env_logger = "0.11" nix = { version = "0.27", features = ["net"] } -futures-lite = "1" +futures-lite = "2" async-channel = "2" +static_cell = "2" [[example]] name = "onoff_light" diff --git a/rs-matter/src/acl.rs b/rs-matter/src/acl.rs index b5715f65..5289b157 100644 --- a/rs-matter/src/acl.rs +++ b/rs-matter/src/acl.rs @@ -15,20 +15,22 @@ * limitations under the License. */ -use core::{cell::RefCell, fmt::Display, num::NonZeroU8}; - -use crate::{ - data_model::objects::{Access, ClusterId, EndptId, Privilege}, - error::{Error, ErrorCode}, - fabric, - interaction_model::messages::GenericPath, - tlv::{self, FromTLV, Nullable, TLVElement, TLVList, TLVWriter, TagType, ToTLV}, - transport::session::{Session, SessionMode, MAX_CAT_IDS_PER_NOC}, - utils::writebuf::WriteBuf, -}; +use core::{fmt::Display, num::NonZeroU8}; + use log::error; + use num_derive::FromPrimitive; +use crate::data_model::objects::{Access, ClusterId, EndptId, Privilege}; +use crate::error::{Error, ErrorCode}; +use crate::fabric; +use crate::interaction_model::messages::GenericPath; +use crate::tlv::{self, FromTLV, Nullable, TLVElement, TLVList, TLVWriter, TagType, ToTLV}; +use crate::transport::session::{Session, SessionMode, MAX_CAT_IDS_PER_NOC}; +use crate::utils::cell::RefCell; +use crate::utils::init::{init, Init}; +use crate::utils::storage::WriteBuf; + // Matter Minimum Requirements pub const SUBJECTS_PER_ENTRY: usize = 4; pub const TARGETS_PER_ENTRY: usize = 3; @@ -301,15 +303,17 @@ pub struct AclEntry { subjects: Subjects, targets: Targets, // TODO: Instead of the direct value, we should consider GlobalElements::FabricIndex + // Note that this field will always be `Some(NN)` when the entry is persisted in storage, + // however, it will be `None` when the entry is coming from the other peer #[tagval(0xFE)] - pub fab_idx: NonZeroU8, + pub fab_idx: Option, } impl AclEntry { pub fn new(fab_idx: NonZeroU8, privilege: Privilege, auth_mode: AuthMode) -> Self { const INIT_SUBJECTS: Option = None; Self { - fab_idx, + fab_idx: Some(fab_idx), privilege, auth_mode, subjects: [INIT_SUBJECTS; SUBJECTS_PER_ENTRY], @@ -368,7 +372,11 @@ impl AclEntry { } // true if both are true - allow && self.fab_idx.get() == accessor.fab_idx + allow + && self + .fab_idx + .map(|fab_idx| fab_idx.get() == accessor.fab_idx) + .unwrap_or(false) } fn match_access_desc(&self, object: &AccessDesc) -> bool { @@ -411,7 +419,7 @@ impl AclEntry { const MAX_ACL_ENTRIES: usize = ENTRIES_PER_FABRIC * fabric::MAX_SUPPORTED_FABRICS; -type AclEntries = heapless::Vec, MAX_ACL_ENTRIES>; +type AclEntries = crate::utils::storage::Vec, MAX_ACL_ENTRIES>; pub struct AclMgr { entries: AclEntries, @@ -425,6 +433,7 @@ impl Default for AclMgr { } impl AclMgr { + /// Create a new ACL Manager #[inline(always)] pub const fn new() -> Self { Self { @@ -433,6 +442,14 @@ impl AclMgr { } } + /// Return an in-place initializer for ACL Manager + pub fn init() -> impl Init { + init!(Self { + entries <- AclEntries::init(), + changed: false, + }) + } + pub fn erase_all(&mut self) -> Result<(), Error> { self.entries.clear(); self.changed = true; @@ -441,13 +458,18 @@ impl AclMgr { } pub fn add(&mut self, entry: AclEntry) -> Result { + let Some(fab_idx) = entry.fab_idx else { + // When persisting entries, the `fab_idx` should always be set + return Err(ErrorCode::Invalid.into()); + }; + if entry.auth_mode == AuthMode::Pase { // Reserved for future use // TODO: Should be something that results in IMStatusCode::ConstraintError Err(ErrorCode::Invalid)?; } - let cnt = self.get_index_in_fabric(MAX_ACL_ENTRIES, entry.fab_idx); + let cnt = self.get_index_in_fabric(MAX_ACL_ENTRIES, fab_idx); if cnt >= ENTRIES_PER_FABRIC as u8 { Err(ErrorCode::NoSpace)?; } @@ -455,8 +477,6 @@ impl AclMgr { let slot = self.entries.iter().position(|a| a.is_none()); if slot.is_some() || self.entries.len() < MAX_ACL_ENTRIES { - let fab_idx = entry.fab_idx; - let slot = if let Some(slot) = slot { self.entries[slot] = Some(entry); @@ -501,7 +521,7 @@ impl AclMgr { for entry in &mut self.entries { if entry .as_ref() - .map(|e| e.fab_idx == fab_idx) + .map(|e| e.fab_idx == Some(fab_idx)) .unwrap_or(false) { *entry = None; @@ -565,7 +585,7 @@ impl AclMgr { pub fn load(&mut self, data: &[u8]) -> Result<(), Error> { let root = TLVList::new(data).iter().next().ok_or(ErrorCode::Invalid)?; - tlv::from_tlv(&mut self.entries, &root)?; + tlv::vec_from_tlv(&mut self.entries, &root)?; self.changed = false; Ok(()) @@ -606,7 +626,11 @@ impl AclMgr { for (curr_index, entry) in self .entries .iter_mut() - .filter(|e| e.as_ref().filter(|e1| e1.fab_idx == fab_idx).is_some()) + .filter(|e| { + e.as_ref() + .filter(|e1| e1.fab_idx == Some(fab_idx)) + .is_some() + }) .enumerate() { if curr_index == index as usize { @@ -625,7 +649,7 @@ impl AclMgr { .iter() .take(till_slot_index) .flatten() - .filter(|e| e.fab_idx == fab_idx) + .filter(|e| e.fab_idx == Some(fab_idx)) .count() as u8 } } @@ -643,7 +667,7 @@ impl core::fmt::Display for AclMgr { #[cfg(test)] #[allow(clippy::bool_assert_comparison)] pub(crate) mod tests { - use core::{cell::RefCell, num::NonZeroU8}; + use core::num::NonZeroU8; use crate::{ acl::{gen_noc_cat, AccessorSubjects}, @@ -651,6 +675,8 @@ pub(crate) mod tests { interaction_model::messages::GenericPath, }; + use crate::utils::cell::RefCell; + use super::{AccessReq, Accessor, AclEntry, AclMgr, AuthMode, Target}; pub(crate) const FAB_1: NonZeroU8 = match NonZeroU8::new(1) { diff --git a/rs-matter/src/cert/mod.rs b/rs-matter/src/cert/mod.rs index c5c158a5..0bbf00fe 100644 --- a/rs-matter/src/cert/mod.rs +++ b/rs-matter/src/cert/mod.rs @@ -21,7 +21,7 @@ use crate::{ crypto::KeyPair, error::{Error, ErrorCode}, tlv::{self, FromTLV, OctetStr, TLVArray, TLVElement, TLVWriter, TagType, ToTLV}, - utils::{epoch::MATTER_CERT_DOESNT_EXPIRE, writebuf::WriteBuf}, + utils::{epoch::MATTER_CERT_DOESNT_EXPIRE, storage::WriteBuf}, }; use log::error; use num_derive::FromPrimitive; @@ -857,7 +857,7 @@ mod tests { use crate::cert::Cert; use crate::tlv::{self, FromTLV, TLVWriter, TagType, ToTLV}; - use crate::utils::writebuf::WriteBuf; + use crate::utils::storage::WriteBuf; #[test] fn test_asn1_encode_success() { diff --git a/rs-matter/src/core.rs b/rs-matter/src/core.rs index a52c18ac..c68db6ce 100644 --- a/rs-matter/src/core.rs +++ b/rs-matter/src/core.rs @@ -15,27 +15,26 @@ * limitations under the License. */ -use core::cell::RefCell; - use embassy_sync::blocking_mutex::raw::NoopRawMutex; -use crate::{ - acl::AclMgr, - data_model::{ - cluster_basic_information::BasicInfoConfig, - sdm::{dev_att::DevAttDataFetcher, failsafe::FailSafe}, - }, - error::*, - fabric::FabricMgr, - mdns::MdnsService, - pairing::{print_pairing_code_and_qr, DiscoveryCapabilities}, - secure_channel::{pake::PaseMgr, spake2p::VerifierData}, - transport::{ - core::{PacketBufferExternalAccess, TransportMgr}, - network::{NetworkReceive, NetworkSend}, - }, - utils::{buf::BufferAccess, epoch::Epoch, notification::Notification, rand::Rand}, +use crate::acl::AclMgr; +use crate::data_model::{ + cluster_basic_information::BasicInfoConfig, + sdm::{dev_att::DevAttDataFetcher, failsafe::FailSafe}, }; +use crate::error::*; +use crate::fabric::FabricMgr; +use crate::mdns::MdnsService; +use crate::pairing::{print_pairing_code_and_qr, DiscoveryCapabilities}; +use crate::secure_channel::{pake::PaseMgr, spake2p::VerifierData}; +use crate::transport::core::{PacketBufferExternalAccess, TransportMgr}; +use crate::transport::network::{NetworkReceive, NetworkSend}; +use crate::utils::cell::RefCell; +use crate::utils::epoch::Epoch; +use crate::utils::init::{init, Init}; +use crate::utils::rand::Rand; +use crate::utils::storage::pooled::BufferAccess; +use crate::utils::sync::Notification; /* The Matter Port */ pub const MATTER_PORT: u16 = 5540; @@ -65,6 +64,16 @@ pub struct Matter<'a> { } impl<'a> Matter<'a> { + /// Create a new Matter object when support for the Rust Standard Library is enabled. + /// + /// # Parameters + /// * dev_det: An object of type [BasicInfoConfig]. + /// * dev_att: An object that implements the trait [DevAttDataFetcher]. Any Matter device + /// requires a set of device attestation certificates and keys. It is the responsibility of + /// this object to return the device attestation details when queried upon. + /// * mdns: An object of type [MdnsService]. This object is responsible for handling mDNS + /// responses and queries related to the operation of the Matter stack. + /// * port: The port number on which the Matter stack will listen for incoming connections. #[cfg(feature = "std")] #[inline(always)] pub const fn new_default( @@ -79,12 +88,19 @@ impl<'a> Matter<'a> { Self::new(dev_det, dev_att, mdns, sys_epoch, sys_rand, port) } - /// Creates a new Matter object + /// Create a new Matter object /// /// # Parameters + /// * dev_det: An object of type [BasicInfoConfig]. /// * dev_att: An object that implements the trait [DevAttDataFetcher]. Any Matter device /// requires a set of device attestation certificates and keys. It is the responsibility of /// this object to return the device attestation details when queried upon. + /// * mdns: An object of type [MdnsService]. This object is responsible for handling mDNS + /// responses and queries related to the operation of the Matter stack. + /// * epoch: A function of type [Epoch]. This function is responsible for providing the current + /// "unix" time in milliseconds + /// * rand: A function of type [Rand]. This function is responsible for generating random data. + /// * port: The port number on which the Matter stack will listen for incoming connections. #[inline(always)] pub const fn new( dev_det: &'a BasicInfoConfig<'a>, @@ -99,7 +115,7 @@ impl<'a> Matter<'a> { acl_mgr: RefCell::new(AclMgr::new()), pase_mgr: RefCell::new(PaseMgr::new(epoch, rand)), failsafe: RefCell::new(FailSafe::new()), - transport_mgr: TransportMgr::new(mdns.new_impl(dev_det, port), epoch, rand), + transport_mgr: TransportMgr::new(mdns, dev_det, port, epoch, rand), persist_notification: Notification::new(), epoch, rand, @@ -109,6 +125,68 @@ impl<'a> Matter<'a> { } } + /// Create an in-place initializer for a Matter object + /// when support for the Rust Standard Library is enabled. + /// + /// # Parameters + /// * dev_det: An object of type [BasicInfoConfig]. + /// * dev_att: An object that implements the trait [DevAttDataFetcher]. Any Matter device + /// requires a set of device attestation certificates and keys. It is the responsibility of + /// this object to return the device attestation details when queried upon. + /// * mdns: An object of type [MdnsService]. This object is responsible for handling mDNS + /// responses and queries related to the operation of the Matter stack. + /// * port: The port number on which the Matter stack will listen for incoming connections. + #[cfg(feature = "std")] + pub fn init_default( + dev_det: &'a BasicInfoConfig<'a>, + dev_att: &'a dyn DevAttDataFetcher, + mdns: MdnsService<'a>, + port: u16, + ) -> impl Init { + use crate::utils::epoch::sys_epoch; + use crate::utils::rand::sys_rand; + + Self::init(dev_det, dev_att, mdns, sys_epoch, sys_rand, port) + } + + /// Create an in-place initializer for a Matter object + /// + /// # Parameters + /// * dev_det: An object of type [BasicInfoConfig]. + /// * dev_att: An object that implements the trait [DevAttDataFetcher]. Any Matter device + /// requires a set of device attestation certificates and keys. It is the responsibility of + /// this object to return the device attestation details when queried upon. + /// * mdns: An object of type [MdnsService]. This object is responsible for handling mDNS + /// responses and queries related to the operation of the Matter stack. + /// * epoch: A function of type [Epoch]. This function is responsible for providing the current + /// "unix" time in milliseconds + /// * rand: A function of type [Rand]. This function is responsible for generating random data. + /// * port: The port number on which the Matter stack will listen for incoming connections. + pub fn init( + dev_det: &'a BasicInfoConfig<'a>, + dev_att: &'a dyn DevAttDataFetcher, + mdns: MdnsService<'a>, + epoch: Epoch, + rand: Rand, + port: u16, + ) -> impl Init { + init!( + Self { + fabric_mgr <- RefCell::init(FabricMgr::init()), + acl_mgr <- RefCell::init(AclMgr::init()), + pase_mgr <- RefCell::init(PaseMgr::init(epoch, rand)), + failsafe: RefCell::new(FailSafe::new()), + transport_mgr <- TransportMgr::init(mdns, dev_det, port, epoch, rand), + persist_notification: Notification::new(), + epoch, + rand, + dev_det, + dev_att, + port, + } + ) + } + pub fn initialize_transport_buffers(&self) -> Result<(), Error> { self.transport_mgr.initialize_buffers() } @@ -155,8 +233,7 @@ impl<'a> Matter<'a> { /// after that - while/if we still have exclusive, mutable access to the `Matter` object - /// replace the `MdnsService::Disabled` initial impl with another, like `MdnsService::Provided`. pub fn replace_mdns(&mut self, mdns: MdnsService<'a>) { - self.transport_mgr - .replace_mdns(mdns.new_impl(self.dev_det, self.port)); + self.transport_mgr.replace_mdns(mdns); } /// A utility method to replace the initial Device Attestation Data Fetcher with another one. diff --git a/rs-matter/src/data_model/core.rs b/rs-matter/src/data_model/core.rs index a3ebeb88..e17115f3 100644 --- a/rs-matter/src/data_model/core.rs +++ b/rs-matter/src/data_model/core.rs @@ -26,7 +26,7 @@ use embassy_time::{Instant, Timer}; use log::{debug, error, info, warn}; use crate::interaction_model::messages::ib::AttrStatus; -use crate::utils::buf::BufferAccess; +use crate::utils::storage::pooled::BufferAccess; use crate::{error::*, Matter}; use crate::interaction_model::core::{ @@ -39,7 +39,7 @@ use crate::interaction_model::messages::msg::{ use crate::respond::ExchangeHandler; use crate::tlv::{get_root_node_struct, FromTLV, TLVWriter, TagType}; use crate::transport::exchange::{Exchange, MAX_EXCHANGE_RX_BUF_SIZE, MAX_EXCHANGE_TX_BUF_SIZE}; -use crate::utils::writebuf::WriteBuf; +use crate::utils::storage::WriteBuf; use super::objects::*; use super::subscriptions::Subscriptions; diff --git a/rs-matter/src/data_model/sdm/noc.rs b/rs-matter/src/data_model/sdm/noc.rs index 64e44f30..52793766 100644 --- a/rs-matter/src/data_model/sdm/noc.rs +++ b/rs-matter/src/data_model/sdm/noc.rs @@ -32,7 +32,7 @@ use crate::tlv::{FromTLV, OctetStr, TLVElement, TLVWriter, TagType, ToTLV, UtfSt use crate::transport::exchange::Exchange; use crate::transport::session::SessionMode; use crate::utils::epoch::Epoch; -use crate::utils::writebuf::WriteBuf; +use crate::utils::storage::WriteBuf; use crate::{attribute_enum, cmd_enter, command_enum, error::*}; use super::dev_att::{DataType, DevAttDataFetcher}; @@ -157,14 +157,14 @@ pub const CLUSTER: Cluster<'static> = Cluster { pub struct NocData { pub key_pair: KeyPair, - pub root_ca: heapless::Vec, + pub root_ca: crate::utils::storage::Vec, } impl NocData { pub fn new(key_pair: KeyPair) -> Self { Self { key_pair, - root_ca: heapless::Vec::new(), + root_ca: crate::utils::storage::Vec::new(), } } } @@ -327,15 +327,16 @@ impl NocCluster { let noc_cert = Cert::new(r.noc_value.0).map_err(|_| NocStatus::InvalidNOC)?; info!("Received NOC as: {}", noc_cert); - let noc = heapless::Vec::from_slice(r.noc_value.0).map_err(|_| NocStatus::InvalidNOC)?; + let noc = crate::utils::storage::Vec::from_slice(r.noc_value.0) + .map_err(|_| NocStatus::InvalidNOC)?; let icac = if let Some(icac_value) = r.icac_value { if !icac_value.0.is_empty() { let icac_cert = Cert::new(icac_value.0).map_err(|_| NocStatus::InvalidNOC)?; info!("Received ICAC as: {}", icac_cert); - let icac = - heapless::Vec::from_slice(icac_value.0).map_err(|_| NocStatus::InvalidNOC)?; + let icac = crate::utils::storage::Vec::from_slice(icac_value.0) + .map_err(|_| NocStatus::InvalidNOC)?; Some(icac) } else { None @@ -661,8 +662,8 @@ impl NocCluster { let req = CommonReq::from_tlv(data).map_err(Error::map_invalid_command)?; info!("Received Trusted Cert:{:x?}", req.str); - noc_data.root_ca = - heapless::Vec::from_slice(req.str.0).map_err(|_| ErrorCode::BufferTooSmall)?; + noc_data.root_ca = crate::utils::storage::Vec::from_slice(req.str.0) + .map_err(|_| ErrorCode::BufferTooSmall)?; Ok(()) }) diff --git a/rs-matter/src/data_model/subscriptions.rs b/rs-matter/src/data_model/subscriptions.rs index e207c511..b6220d7f 100644 --- a/rs-matter/src/data_model/subscriptions.rs +++ b/rs-matter/src/data_model/subscriptions.rs @@ -15,7 +15,6 @@ * limitations under the License. */ -use core::cell::RefCell; use core::num::NonZeroU8; use embassy_sync::blocking_mutex::raw::NoopRawMutex; @@ -23,7 +22,9 @@ use embassy_time::Instant; use portable_atomic::{AtomicU32, Ordering}; -use crate::utils::notification::Notification; +use crate::utils::cell::RefCell; +use crate::utils::init::{init, Init}; +use crate::utils::sync::Notification; struct Subscription { fabric_idx: NonZeroU8, @@ -62,7 +63,7 @@ impl Subscription { /// Additional subscriptions are rejected by the data model with a "resource exhausted" IM status message. pub struct Subscriptions { next_subscription_id: AtomicU32, - subscriptions: RefCell>, + subscriptions: RefCell>, pub(crate) notification: Notification, } @@ -78,11 +79,20 @@ impl Subscriptions { pub const fn new() -> Self { Self { next_subscription_id: AtomicU32::new(1), - subscriptions: RefCell::new(heapless::Vec::new()), + subscriptions: RefCell::new(crate::utils::storage::Vec::new()), notification: Notification::new(), } } + /// Create an in-place initializer for the instance. + pub fn init() -> impl Init { + init!(Self { + next_subscription_id: AtomicU32::new(1), + subscriptions <- RefCell::init(crate::utils::storage::Vec::init()), + notification: Notification::new(), + }) + } + /// Notify the instance that some data in the data model has changed and that it should re-evaluate the subscriptions /// and report on those that concern the changed data. /// diff --git a/rs-matter/src/data_model/system_model/access_control.rs b/rs-matter/src/data_model/system_model/access_control.rs index aa7daeb0..7ee58133 100644 --- a/rs-matter/src/data_model/system_model/access_control.rs +++ b/rs-matter/src/data_model/system_model/access_control.rs @@ -134,7 +134,12 @@ impl AccessControlCluster { Attributes::Acl(_) => { writer.start_array(AttrDataWriter::TAG)?; acl_mgr.for_each_acl(|entry| { - if !attr.fab_filter || attr.fab_idx == entry.fab_idx.get() { + if !attr.fab_filter + || entry + .fab_idx + .map(|fi| fi.get() == attr.fab_idx) + .unwrap_or(false) + { entry.to_tlv(&mut writer, TagType::Anonymous)?; } @@ -184,7 +189,7 @@ impl AccessControlCluster { let mut acl_entry = AclEntry::from_tlv(data)?; info!("ACL {:?}", acl_entry); // Overwrite the fabric index with our accessing fabric index - acl_entry.fab_idx = fab_idx; + acl_entry.fab_idx = Some(fab_idx); if let ListOperation::EditItem(index) = op { acl_mgr.edit(*index as u8, fab_idx, acl_entry)?; @@ -230,7 +235,7 @@ mod tests { use crate::data_model::system_model::access_control::Dataver; use crate::interaction_model::messages::ib::ListOperation; use crate::tlv::{get_root_node_struct, ElementType, TLVElement, TLVWriter, TagType, ToTLV}; - use crate::utils::writebuf::WriteBuf; + use crate::utils::storage::WriteBuf; use super::AccessControlCluster; diff --git a/rs-matter/src/fabric.rs b/rs-matter/src/fabric.rs index cf71a63c..f6d8ba19 100644 --- a/rs-matter/src/fabric.rs +++ b/rs-matter/src/fabric.rs @@ -19,18 +19,19 @@ use core::fmt::Write; use core::num::NonZeroU8; use byteorder::{BigEndian, ByteOrder, LittleEndian}; -use heapless::{String, Vec}; + +use heapless::String; + use log::info; -use crate::{ - cert::{Cert, MAX_CERT_TLV_LEN}, - crypto::{self, hkdf_sha256, HmacSha256, KeyPair}, - error::{Error, ErrorCode}, - group_keys::KeySet, - mdns::{Mdns, ServiceMode}, - tlv::{self, FromTLV, OctetStr, TLVList, TLVWriter, TagType, ToTLV, UtfStr}, - utils::writebuf::WriteBuf, -}; +use crate::cert::{Cert, MAX_CERT_TLV_LEN}; +use crate::crypto::{self, hkdf_sha256, HmacSha256, KeyPair}; +use crate::error::{Error, ErrorCode}; +use crate::group_keys::KeySet; +use crate::mdns::{Mdns, ServiceMode}; +use crate::tlv::{self, FromTLV, OctetStr, TLVList, TLVWriter, TagType, ToTLV, UtfStr}; +use crate::utils::init::{init, Init}; +use crate::utils::storage::{Vec, WriteBuf}; const COMPRESSED_FABRIC_ID_LEN: usize = 8; @@ -64,9 +65,9 @@ pub struct Fabric { impl Fabric { pub fn new( key_pair: KeyPair, - root_ca: heapless::Vec, - icac: Option>, - noc: heapless::Vec, + root_ca: Vec, + icac: Option>, + noc: Vec, ipk: &[u8], vendor_id: u16, label: &str, @@ -206,6 +207,13 @@ impl FabricMgr { } } + pub fn init() -> impl Init { + init!(Self { + fabrics <- FabricEntries::init(), + changed: false, + }) + } + pub fn load(&mut self, data: &[u8], mdns: &dyn Mdns) -> Result<(), Error> { for fabric in self.fabrics.iter().flatten() { mdns.remove(&fabric.mdns_service_name)?; @@ -213,7 +221,7 @@ impl FabricMgr { let root = TLVList::new(data).iter().next().ok_or(ErrorCode::Invalid)?; - tlv::from_tlv(&mut self.fabrics, &root)?; + tlv::vec_from_tlv(&mut self.fabrics, &root)?; for fabric in self.fabrics.iter().flatten() { mdns.add(&fabric.mdns_service_name, ServiceMode::Commissioned)?; diff --git a/rs-matter/src/interaction_model/core.rs b/rs-matter/src/interaction_model/core.rs index 772b0411..7266e0a3 100644 --- a/rs-matter/src/interaction_model/core.rs +++ b/rs-matter/src/interaction_model/core.rs @@ -21,7 +21,7 @@ use crate::{ error::*, tlv::{FromTLV, TLVArray, TLVElement, TLVWriter, TagType, ToTLV}, transport::exchange::MessageMeta, - utils::{epoch::Epoch, writebuf::WriteBuf}, + utils::{epoch::Epoch, storage::WriteBuf}, }; use num::FromPrimitive; use num_derive::FromPrimitive; diff --git a/rs-matter/src/mdns.rs b/rs-matter/src/mdns.rs index a697275b..8852c1bb 100644 --- a/rs-matter/src/mdns.rs +++ b/rs-matter/src/mdns.rs @@ -17,7 +17,9 @@ use core::fmt::Write; -use crate::{data_model::cluster_basic_information::BasicInfoConfig, error::Error}; +use crate::data_model::cluster_basic_information::BasicInfoConfig; +use crate::error::Error; +use crate::utils::init::{init, Init}; #[cfg(all(feature = "std", target_os = "macos"))] #[path = "mdns/astro.rs"] @@ -98,48 +100,66 @@ pub enum MdnsService<'a> { Provided(&'a dyn Mdns), } -impl<'a> MdnsService<'a> { - pub(crate) const fn new_impl( - &self, +pub(crate) struct MdnsImpl<'a> { + service: MdnsService<'a>, + builtin: builtin::MdnsImpl<'a>, +} + +impl<'a> MdnsImpl<'a> { + pub(crate) const fn new( + service: MdnsService<'a>, dev_det: &'a BasicInfoConfig<'a>, matter_port: u16, - ) -> MdnsImpl<'a> { - match self { - Self::Disabled => MdnsImpl::Disabled, - Self::Builtin => MdnsImpl::Builtin(builtin::MdnsImpl::new(dev_det, matter_port)), - Self::Provided(mdns) => MdnsImpl::Provided(*mdns), + ) -> Self { + Self { + service, + builtin: builtin::MdnsImpl::new(dev_det, matter_port), } } -} -pub(crate) enum MdnsImpl<'a> { - Disabled, - Builtin(builtin::MdnsImpl<'a>), - Provided(&'a dyn Mdns), + pub(crate) fn init( + service: MdnsService<'a>, + dev_det: &'a BasicInfoConfig<'a>, + matter_port: u16, + ) -> impl Init { + init!(Self { + service, + builtin <- builtin::MdnsImpl::init(dev_det, matter_port), + }) + } + + #[allow(unused)] + pub(crate) fn builtin(&self) -> Option<&builtin::MdnsImpl> { + matches!(self.service, MdnsService::Builtin).then_some(&self.builtin) + } + + pub(crate) fn update(&mut self, service: MdnsService<'a>) { + self.service = service; + } } impl<'a> Mdns for MdnsImpl<'a> { fn reset(&self) { - match self { - Self::Disabled => {} - Self::Builtin(mdns) => mdns.reset(), - Self::Provided(mdns) => mdns.reset(), + match self.service { + MdnsService::Disabled => {} + MdnsService::Builtin => self.builtin.reset(), + MdnsService::Provided(mdns) => mdns.reset(), } } fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error> { - match self { - Self::Disabled => Ok(()), - Self::Builtin(mdns) => mdns.add(service, mode), - Self::Provided(mdns) => mdns.add(service, mode), + match self.service { + MdnsService::Disabled => Ok(()), + MdnsService::Builtin => self.builtin.add(service, mode), + MdnsService::Provided(mdns) => mdns.add(service, mode), } } fn remove(&self, service: &str) -> Result<(), Error> { - match self { - Self::Disabled => Ok(()), - Self::Builtin(mdns) => mdns.remove(service), - Self::Provided(mdns) => mdns.remove(service), + match self.service { + MdnsService::Disabled => Ok(()), + MdnsService::Builtin => self.builtin.remove(service), + MdnsService::Provided(mdns) => mdns.remove(service), } } } diff --git a/rs-matter/src/mdns/astro.rs b/rs-matter/src/mdns/astro.rs index 06c67eb8..413e0614 100644 --- a/rs-matter/src/mdns/astro.rs +++ b/rs-matter/src/mdns/astro.rs @@ -1,15 +1,13 @@ -use core::cell::RefCell; - use std::collections::BTreeMap; use astro_dnssd::{DNSServiceBuilder, RegisteredDnsService}; use log::info; -use crate::{ - data_model::cluster_basic_information::BasicInfoConfig, - error::{Error, ErrorCode}, -}; +use crate::data_model::cluster_basic_information::BasicInfoConfig; +use crate::error::{Error, ErrorCode}; +use crate::utils::cell::RefCell; +use crate::utils::init::{init, Init}; use super::ServiceMode; @@ -28,6 +26,14 @@ impl<'a> MdnsImpl<'a> { } } + pub fn init(dev_det: &'a BasicInfoConfig<'a>, matter_port: u16) -> impl Init { + init!(Self { + dev_det, + matter_port, + services <- RefCell::init(BTreeMap::new()), + }) + } + pub fn reset(&self) { self.services.borrow_mut().clear(); } diff --git a/rs-matter/src/mdns/builtin.rs b/rs-matter/src/mdns/builtin.rs index 235f1bc3..89207e6e 100644 --- a/rs-matter/src/mdns/builtin.rs +++ b/rs-matter/src/mdns/builtin.rs @@ -1,4 +1,3 @@ -use core::cell::RefCell; use core::net::IpAddr; use core::pin::pin; @@ -6,6 +5,7 @@ use embassy_futures::select::select; use embassy_sync::blocking_mutex::raw::{NoopRawMutex, RawMutex}; use embassy_sync::mutex::Mutex; use embassy_time::{Duration, Timer}; + use log::{info, warn}; use crate::data_model::cluster_basic_information::BasicInfoConfig; @@ -14,8 +14,12 @@ use crate::transport::network::{ Address, Ipv4Addr, Ipv6Addr, NetworkReceive, NetworkSend, SocketAddr, SocketAddrV4, SocketAddrV6, }; +use crate::utils::cell::RefCell; +use crate::utils::init::{init, Init}; use crate::utils::rand::Rand; -use crate::utils::{buf::BufferAccess, notification::Notification, select::Coalesce}; +use crate::utils::select::Coalesce; +use crate::utils::storage::pooled::BufferAccess; +use crate::utils::sync::Notification; use super::{Service, ServiceMode}; @@ -37,7 +41,7 @@ pub const MDNS_PORT: u16 = 5353; pub struct MdnsImpl<'a> { dev_det: &'a BasicInfoConfig<'a>, matter_port: u16, - services: RefCell, ServiceMode), 4>>, + services: RefCell, ServiceMode), 4>>, notification: Notification, } @@ -47,11 +51,20 @@ impl<'a> MdnsImpl<'a> { Self { dev_det, matter_port, - services: RefCell::new(heapless::Vec::new()), + services: RefCell::new(crate::utils::storage::Vec::new()), notification: Notification::new(), } } + pub fn init(dev_det: &'a BasicInfoConfig<'a>, matter_port: u16) -> impl Init { + init!(Self { + dev_det, + matter_port, + services <- RefCell::init(crate::utils::storage::Vec::init()), + notification: Notification::new(), + }) + } + pub fn reset(&self) { self.services.borrow_mut().clear(); } @@ -136,11 +149,11 @@ impl<'a> MdnsImpl<'a> { select(&mut notification, &mut timeout).await; - for addr in core::iter::once(SocketAddr::V4(SocketAddrV4::new( - MDNS_IPV4_BROADCAST_ADDR, - MDNS_PORT, - ))) - .chain( + for addr in Iterator::chain( + core::iter::once(SocketAddr::V4(SocketAddrV4::new( + MDNS_IPV4_BROADCAST_ADDR, + MDNS_PORT, + ))), interface .map(|interface| { SocketAddr::V6(SocketAddrV6::new( diff --git a/rs-matter/src/mdns/proto.rs b/rs-matter/src/mdns/proto.rs index bd2ed06b..2ba952dd 100644 --- a/rs-matter/src/mdns/proto.rs +++ b/rs-matter/src/mdns/proto.rs @@ -1064,7 +1064,7 @@ mod tests { assert_eq!(t, str); } - while let Some(t) = txt.next() { + for t in txt { if !t.is_empty() { panic!("Unexpected TXT string {:?} for {}", t, expected.owner); } diff --git a/rs-matter/src/mdns/zeroconf.rs b/rs-matter/src/mdns/zeroconf.rs index 5a04ea04..db5b45d1 100644 --- a/rs-matter/src/mdns/zeroconf.rs +++ b/rs-matter/src/mdns/zeroconf.rs @@ -1,5 +1,3 @@ -use core::cell::RefCell; - use std::collections::BTreeMap; use std::sync::mpsc::{sync_channel, SyncSender}; @@ -7,10 +5,10 @@ use log::error; use zeroconf::{prelude::TEventLoop, service::TMdnsService, txt_record::TTxtRecord, ServiceType}; -use crate::{ - data_model::cluster_basic_information::BasicInfoConfig, - error::{Error, ErrorCode}, -}; +use crate::data_model::cluster_basic_information::BasicInfoConfig; +use crate::error::{Error, ErrorCode}; +use crate::utils::cell::RefCell; +use crate::utils::init::{init, Init}; use super::ServiceMode; @@ -39,6 +37,14 @@ impl<'a> MdnsImpl<'a> { } } + pub fn init(dev_det: &'a BasicInfoConfig<'a>, matter_port: u16) -> impl Init { + init!(Self { + dev_det, + matter_port, + services <- RefCell::init(BTreeMap::new()), + }) + } + pub fn reset(&self) { self.services.borrow_mut().clear(); } diff --git a/rs-matter/src/pairing/qr.rs b/rs-matter/src/pairing/qr.rs index a5c08650..44c14af1 100644 --- a/rs-matter/src/pairing/qr.rs +++ b/rs-matter/src/pairing/qr.rs @@ -20,7 +20,7 @@ use qrcodegen_no_heap::{QrCode, QrCodeEcc, Version}; use crate::{ error::ErrorCode, tlv::{ElementType, TLVElement, TLVWriter, TagType, ToTLV}, - utils::writebuf::WriteBuf, + utils::storage::WriteBuf, }; use super::{ diff --git a/rs-matter/src/persist.rs b/rs-matter/src/persist.rs index a25b13a0..4936c544 100644 --- a/rs-matter/src/persist.rs +++ b/rs-matter/src/persist.rs @@ -19,58 +19,96 @@ pub use fileio::*; #[cfg(feature = "std")] pub mod fileio { + use core::mem::MaybeUninit; + use std::fs; use std::io::{Read, Write}; - use std::path::{Path, PathBuf}; + use std::path::Path; use log::info; use crate::error::{Error, ErrorCode}; + use crate::utils::init::{init, Init}; use crate::Matter; - pub struct Psm<'a> { - matter: &'a Matter<'a>, - dir: PathBuf, - buf: [u8; 4096], + pub struct Psm { + buf: MaybeUninit<[u8; N]>, + } + + impl Default for Psm { + fn default() -> Self { + Self::new() + } } - impl<'a> Psm<'a> { + impl Psm { #[inline(always)] - pub fn new(matter: &'a Matter<'a>, dir: PathBuf) -> Result { - fs::create_dir_all(&dir)?; + pub const fn new() -> Self { + Self { + buf: MaybeUninit::uninit(), + } + } - info!("Persisting from/to {}", dir.display()); + pub fn init() -> impl Init { + init!(Self { + buf <- crate::utils::init::zeroed(), + }) + } - let mut buf = [0; 4096]; + pub fn load(&mut self, dir: &Path, matter: &Matter) -> Result<(), Error> { + fs::create_dir_all(dir)?; - if let Some(data) = Self::load(&dir, "acls", &mut buf)? { + if let Some(data) = Self::load_key(dir, "acls", unsafe { self.buf.assume_init_mut() })? + { matter.load_acls(data)?; } - if let Some(data) = Self::load(&dir, "fabrics", &mut buf)? { + if let Some(data) = + Self::load_key(dir, "fabrics", unsafe { self.buf.assume_init_mut() })? + { matter.load_fabrics(data)?; } - Ok(Self { matter, dir, buf }) + Ok(()) } - pub async fn run(&mut self) -> Result<(), Error> { - loop { - self.matter.wait_changed().await; + pub fn store(&mut self, dir: &Path, matter: &Matter) -> Result<(), Error> { + if matter.is_changed() { + fs::create_dir_all(dir)?; - if self.matter.is_changed() { - if let Some(data) = self.matter.store_acls(&mut self.buf)? { - Self::store(&self.dir, "acls", data)?; - } + if let Some(data) = matter.store_acls(unsafe { self.buf.assume_init_mut() })? { + Self::store_key(dir, "acls", data)?; + } - if let Some(data) = self.matter.store_fabrics(&mut self.buf)? { - Self::store(&self.dir, "fabrics", data)?; - } + if let Some(data) = matter.store_fabrics(unsafe { self.buf.assume_init_mut() })? { + Self::store_key(dir, "fabrics", data)?; } } + + Ok(()) + } + + pub async fn run>( + &mut self, + dir: P, + matter: &Matter<'_>, + ) -> Result<(), Error> { + let dir = dir.as_ref(); + + self.load(dir, matter)?; + + loop { + matter.wait_changed().await; + + self.store(dir, matter)?; + } } - fn load<'b>(dir: &Path, key: &str, buf: &'b mut [u8]) -> Result, Error> { + fn load_key<'b>( + dir: &Path, + key: &str, + buf: &'b mut [u8], + ) -> Result, Error> { let path = dir.join(key); match fs::File::open(path) { @@ -101,7 +139,7 @@ pub mod fileio { } } - fn store(dir: &Path, key: &str, data: &[u8]) -> Result<(), Error> { + fn store_key(dir: &Path, key: &str, data: &[u8]) -> Result<(), Error> { let path = dir.join(key); let mut file = fs::File::create(path)?; diff --git a/rs-matter/src/respond.rs b/rs-matter/src/respond.rs index fd56fe7b..2abf6f46 100644 --- a/rs-matter/src/respond.rs +++ b/rs-matter/src/respond.rs @@ -31,8 +31,8 @@ use crate::interaction_model::core::PROTO_ID_INTERACTION_MODEL; use crate::secure_channel::busy::BusySecureChannel; use crate::secure_channel::core::SecureChannel; use crate::transport::exchange::Exchange; -use crate::utils::buf::BufferAccess; use crate::utils::select::Coalesce; +use crate::utils::storage::pooled::BufferAccess; use crate::Matter; /// Send a busy response if - after that many ms - the exchange diff --git a/rs-matter/src/secure_channel/case.rs b/rs-matter/src/secure_channel/case.rs index 1ecfd6cd..e1c4fe7c 100644 --- a/rs-matter/src/secure_channel/case.rs +++ b/rs-matter/src/secure_channel/case.rs @@ -31,7 +31,7 @@ use crate::{ exchange::Exchange, session::{NocCatIds, ReservedSession, SessionMode}, }, - utils::{rand::Rand, writebuf::WriteBuf}, + utils::{rand::Rand, storage::WriteBuf}, }; #[derive(Debug, Clone)] diff --git a/rs-matter/src/secure_channel/common.rs b/rs-matter/src/secure_channel/common.rs index 4d4d3921..8b559153 100644 --- a/rs-matter/src/secure_channel/common.rs +++ b/rs-matter/src/secure_channel/common.rs @@ -19,7 +19,7 @@ use num_derive::FromPrimitive; use crate::error::Error; use crate::transport::exchange::{Exchange, MessageMeta}; -use crate::utils::writebuf::WriteBuf; +use crate::utils::storage::WriteBuf; use super::status_report::{GeneralCode, StatusReport}; diff --git a/rs-matter/src/secure_channel/pake.rs b/rs-matter/src/secure_channel/pake.rs index d00434c9..9e7297fa 100644 --- a/rs-matter/src/secure_channel/pake.rs +++ b/rs-matter/src/secure_channel/pake.rs @@ -17,23 +17,25 @@ use core::{fmt::Write, time::Duration}; -use super::{ - common::SCStatusCodes, - spake2p::{Spake2P, VerifierData, MAX_SALT_SIZE_BYTES}, +use log::{error, info}; + +use crate::crypto; +use crate::error::{Error, ErrorCode}; +use crate::mdns::{Mdns, ServiceMode}; +use crate::secure_channel::common::{complete_with_status, OpCode}; +use crate::tlv::{self, get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType, ToTLV}; +use crate::transport::{ + exchange::{Exchange, ExchangeId}, + session::{ReservedSession, SessionMode}, }; -use crate::{ - crypto, - error::{Error, ErrorCode}, - mdns::{Mdns, ServiceMode}, - secure_channel::common::{complete_with_status, OpCode}, - tlv::{self, get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType, ToTLV}, - transport::{ - exchange::{Exchange, ExchangeId}, - session::{ReservedSession, SessionMode}, - }, - utils::{epoch::Epoch, rand::Rand}, +use crate::utils::{ + epoch::Epoch, + init::{init, Init}, + rand::Rand, }; -use log::{error, info}; + +use super::common::SCStatusCodes; +use super::spake2p::{Spake2P, VerifierData, MAX_SALT_SIZE_BYTES}; struct PaseSession { mdns_service_name: heapless::String<16>, @@ -58,6 +60,17 @@ impl PaseMgr { } } + pub fn init(epoch: Epoch, rand: Rand) -> impl Init { + // TODO: Optimize in future because `PaseSession` is + // relatively large and we are creating it using stack moves. + init!(Self { + session: None, + timeout: None, + epoch, + rand, + }) + } + pub fn is_pase_session_enabled(&self) -> bool { self.session.is_some() } diff --git a/rs-matter/src/secure_channel/status_report.rs b/rs-matter/src/secure_channel/status_report.rs index d365dd1a..bb433a17 100644 --- a/rs-matter/src/secure_channel/status_report.rs +++ b/rs-matter/src/secure_channel/status_report.rs @@ -19,7 +19,7 @@ use num_derive::FromPrimitive; use crate::{ error::{Error, ErrorCode}, - utils::{parsebuf::ParseBuf, writebuf::WriteBuf}, + utils::storage::{ParseBuf, WriteBuf}, }; #[allow(dead_code)] diff --git a/rs-matter/src/tlv/traits.rs b/rs-matter/src/tlv/traits.rs index acb47c01..2839fb06 100644 --- a/rs-matter/src/tlv/traits.rs +++ b/rs-matter/src/tlv/traits.rs @@ -79,6 +79,24 @@ pub fn from_tlv<'a, T: FromTLV<'a>, const N: usize>( Ok(()) } +pub fn vec_from_tlv<'a, T: FromTLV<'a>, const N: usize>( + vec: &mut crate::utils::storage::Vec, + t: &TLVElement<'a>, +) -> Result<(), Error> { + vec.clear(); + + t.confirm_array()?; + + if let Some(tlv_iter) = t.enter() { + for element in tlv_iter { + vec.push(T::from_tlv(&element)?) + .map_err(|_| ErrorCode::NoSpace)?; + } + } + + Ok(()) +} + macro_rules! fromtlv_for { ($($t:ident)*) => { $( @@ -237,6 +255,19 @@ impl ToTLV for heapless::Vec { } } +/// Implements the Owned version of Octet String +impl FromTLV<'_> for crate::utils::storage::Vec { + fn from_tlv(t: &TLVElement) -> Result, Error> { + crate::utils::storage::Vec::from_slice(t.slice()?).map_err(|_| ErrorCode::NoSpace.into()) + } +} + +impl ToTLV for crate::utils::storage::Vec { + fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { + tw.str16(tag, self.as_slice()) + } +} + /// Implements the Owned version of UTF String impl FromTLV<'_> for heapless::String { fn from_tlv(t: &TLVElement) -> Result, Error> { @@ -538,7 +569,8 @@ macro_rules! bitflags_tlv { #[cfg(test)] mod tests { use super::{FromTLV, OctetStr, TLVWriter, TagType, ToTLV}; - use crate::{tlv::TLVList, utils::writebuf::WriteBuf}; + use crate::tlv::TLVList; + use crate::utils::storage::WriteBuf; use rs_matter_macros::{FromTLV, ToTLV}; #[derive(ToTLV)] diff --git a/rs-matter/src/tlv/writer.rs b/rs-matter/src/tlv/writer.rs index 45c60c97..e6af4efe 100644 --- a/rs-matter/src/tlv/writer.rs +++ b/rs-matter/src/tlv/writer.rs @@ -15,10 +15,13 @@ * limitations under the License. */ -use super::{TagType, TAG_SHIFT_BITS, TAG_SIZE_MAP}; -use crate::{error::*, utils::writebuf::WriteBuf}; use log::error; +use crate::error::*; +use crate::utils::storage::WriteBuf; + +use super::{TagType, TAG_SHIFT_BITS, TAG_SIZE_MAP}; + #[allow(dead_code)] enum WriteElementType { S8 = 0, @@ -273,7 +276,7 @@ impl<'a, 'b> TLVWriter<'a, 'b> { #[cfg(test)] mod tests { use super::{TLVWriter, TagType}; - use crate::utils::writebuf::WriteBuf; + use crate::utils::storage::WriteBuf; #[test] fn test_write_success() { diff --git a/rs-matter/src/transport/core.rs b/rs-matter/src/transport/core.rs index 3559e35e..b1e386bd 100644 --- a/rs-matter/src/transport/core.rs +++ b/rs-matter/src/transport/core.rs @@ -15,7 +15,6 @@ * limitations under the License. */ -use core::cell::RefCell; use core::fmt::{self, Display}; use core::ops::{Deref, DerefMut}; use core::pin::pin; @@ -26,21 +25,19 @@ use embassy_time::Timer; use log::{debug, error, info, trace, warn}; +use crate::data_model::cluster_basic_information::BasicInfoConfig; use crate::error::{Error, ErrorCode}; -use crate::mdns::MdnsImpl; +use crate::mdns::{MdnsImpl, MdnsService}; use crate::secure_channel::common::{sc_write, OpCode, SCStatusCodes, PROTO_ID_SECURE_CHANNEL}; use crate::secure_channel::status_report::StatusReport; use crate::tlv::TLVList; -use crate::utils::buf::BufferAccess; -use crate::utils::{ - epoch::Epoch, - ifmutex::{IfMutex, IfMutexGuard}, - notification::Notification, - parsebuf::ParseBuf, - rand::Rand, - select::Coalesce, - writebuf::WriteBuf, -}; +use crate::utils::cell::RefCell; +use crate::utils::epoch::Epoch; +use crate::utils::init::{init, Init}; +use crate::utils::rand::Rand; +use crate::utils::select::Coalesce; +use crate::utils::storage::{pooled::BufferAccess, ParseBuf, WriteBuf}; +use crate::utils::sync::{IfMutex, IfMutexGuard, Notification}; use crate::{Matter, MATTER_PORT}; use super::exchange::{Exchange, ExchangeId, ExchangeState, MessageMeta, ResponderState, Role}; @@ -86,33 +83,58 @@ pub struct TransportMgr<'m> { impl<'m> TransportMgr<'m> { #[inline(always)] - pub(crate) const fn new(mdns: MdnsImpl<'m>, epoch: Epoch, rand: Rand) -> Self { + pub(crate) const fn new( + service: MdnsService<'m>, + dev_det: &'m BasicInfoConfig<'m>, + matter_port: u16, + epoch: Epoch, + rand: Rand, + ) -> Self { Self { rx: IfMutex::new(Packet::new()), tx: IfMutex::new(Packet::new()), dropped: Notification::new(), session_removed: Notification::new(), session_mgr: RefCell::new(SessionMgr::new(epoch, rand)), - mdns, + mdns: MdnsImpl::new(service, dev_det, matter_port), rand, } } - pub(crate) fn replace_mdns(&mut self, mdns: MdnsImpl<'m>) { - self.mdns = mdns; + pub(crate) fn init( + service: MdnsService<'m>, + dev_det: &'m BasicInfoConfig<'m>, + matter_port: u16, + epoch: Epoch, + rand: Rand, + ) -> impl Init { + init!(Self { + rx <- IfMutex::init(Packet::init()), + tx <- IfMutex::init(Packet::init()), + dropped: Notification::new(), + session_removed: Notification::new(), + session_mgr <- RefCell::init(SessionMgr::init(epoch, rand)), + mdns <- MdnsImpl::init(service, dev_det, matter_port), + rand, + }) } + pub(crate) fn replace_mdns(&mut self, mdns: MdnsService<'m>) { + self.mdns.update(mdns); + } + + // TODO #[cfg(all(feature = "large-buffers", feature = "alloc"))] pub fn initialize_buffers(&self) -> Result<(), Error> { let mut rx = self.rx.try_lock().map_err(|_| ErrorCode::InvalidState)?; let mut tx = self.tx.try_lock().map_err(|_| ErrorCode::InvalidState)?; if rx.buf.0.is_none() { - rx.buf.0 = Some(alloc::boxed::Box::new(heapless::Vec::new())); + rx.buf.0 = Some(alloc::boxed::Box::new(crate::utils::storage::Vec::new())); } if tx.buf.0.is_none() { - tx.buf.0 = Some(alloc::boxed::Box::new(heapless::Vec::new())); + tx.buf.0 = Some(alloc::boxed::Box::new(crate::utils::storage::Vec::new())); } Ok(()) @@ -274,7 +296,7 @@ impl<'m> TransportMgr<'m> { { info!("Running Matter built-in mDNS service"); - if let MdnsImpl::Builtin(mdns) = &self.mdns { + if let Some(mdns) = self.mdns.builtin() { mdns.run( send, recv, @@ -1046,6 +1068,15 @@ impl Packet { } } + pub(crate) fn init() -> impl Init { + init!(Self { + peer: Address::new(), + header: PacketHdr::new(), + buf <- PacketBuffer::init(), + payload_start: 0, + }) + } + pub fn display<'a>(peer: &'a Address, header: &'a PacketHdr) -> impl Display + 'a { struct PacketInfo<'a>(&'a Address, &'a PacketHdr); @@ -1116,54 +1147,66 @@ impl Display for Packet { // // This type is only known and used by `TransportMgr` and the `exchange` module #[cfg(all(feature = "large-buffers", feature = "alloc"))] -pub(crate) struct PacketBuffer(Option>>); +pub(crate) struct PacketBuffer( + Option>>, +); // The buffer used inside the pair of RX and TX `Packet` instances // When the either of the `alloc` and `large-buffers` features is not enabled, the buffer payload is allocated inline // // This type is only known and used by `TransportMgr` and the `exchange` module #[cfg(not(all(feature = "large-buffers", feature = "alloc")))] -pub(crate) struct PacketBuffer(heapless::Vec); +pub(crate) struct PacketBuffer { + buffer: crate::utils::storage::Vec, +} impl PacketBuffer { #[cfg(all(feature = "large-buffers", feature = "alloc"))] pub const fn new() -> Self { - Self(None) + Self { buffer: None } } #[cfg(not(all(feature = "large-buffers", feature = "alloc")))] pub const fn new() -> Self { - Self(heapless::Vec::new()) + Self { + buffer: crate::utils::storage::Vec::new(), + } + } + + pub fn init() -> impl Init { + init!(Self { + buffer <- crate::utils::storage::Vec::init(), + }) } #[cfg(all(feature = "large-buffers", feature = "alloc"))] - pub fn buf_mut(&mut self) -> &mut heapless::Vec { + pub fn buf_mut(&mut self) -> &mut crate::utils::storage::Vec { &mut *self - .0 + .buffer .as_mut() .expect("Buffer is not allocated. Did you forget to call `initialize_buffers`?") } #[cfg(not(all(feature = "large-buffers", feature = "alloc")))] - pub fn buf_mut(&mut self) -> &mut heapless::Vec { - &mut self.0 + pub fn buf_mut(&mut self) -> &mut crate::utils::storage::Vec { + &mut self.buffer } #[cfg(all(feature = "large-buffers", feature = "alloc"))] - pub fn buf_ref(&self) -> &heapless::Vec { - self.0 + pub fn buf_ref(&self) -> crate::utils::storage::Vec { + self.buffer .as_ref() .expect("Buffer is not allocated. Did you forget to call `initialize_buffers`?") } #[cfg(not(all(feature = "large-buffers", feature = "alloc")))] - pub fn buf_ref(&self) -> &heapless::Vec { - &self.0 + pub fn buf_ref(&self) -> &crate::utils::storage::Vec { + &self.buffer } } impl Deref for PacketBuffer { - type Target = heapless::Vec; + type Target = crate::utils::storage::Vec; fn deref(&self) -> &Self::Target { self.buf_ref() diff --git a/rs-matter/src/transport/exchange.rs b/rs-matter/src/transport/exchange.rs index 4d800251..5ec3d10a 100644 --- a/rs-matter/src/transport/exchange.rs +++ b/rs-matter/src/transport/exchange.rs @@ -27,7 +27,8 @@ use crate::acl::Accessor; use crate::error::{Error, ErrorCode}; use crate::interaction_model::{self, core::PROTO_ID_INTERACTION_MODEL}; use crate::secure_channel::{self, common::PROTO_ID_SECURE_CHANNEL}; -use crate::utils::{epoch::Epoch, writebuf::WriteBuf}; +use crate::utils::epoch::Epoch; +use crate::utils::storage::WriteBuf; use crate::Matter; use super::core::{Packet, PacketAccess, MAX_RX_BUF_SIZE, MAX_TX_BUF_SIZE}; diff --git a/rs-matter/src/transport/network/btp.rs b/rs-matter/src/transport/network/btp.rs index 06724ba6..23cd5dc4 100644 --- a/rs-matter/src/transport/network/btp.rs +++ b/rs-matter/src/transport/network/btp.rs @@ -27,13 +27,15 @@ use embassy_time::{Duration, Instant, Timer}; use log::trace; use context::LockError; + use session::{BTP_ACK_TIMEOUT_SECS, BTP_CONN_IDLE_TIMEOUT_SECS}; use crate::data_model::cluster_basic_information::BasicInfoConfig; use crate::error::{Error, ErrorCode}; use crate::transport::network::{Address, BtAddr, NetworkReceive, NetworkSend}; -use crate::utils::ifmutex::IfMutex; +use crate::utils::init::{init, Init}; use crate::utils::select::Coalesce; +use crate::utils::sync::IfMutex; use crate::CommissioningData; pub use context::{BtpContext, MAX_BTP_SESSIONS}; @@ -65,7 +67,7 @@ pub(crate) const MAX_MTU: u16 = (MAX_BTP_SEGMENT_SIZE + GATT_HEADER_SIZE) as u16 pub struct Btp { gatt: T, context: C, - send_buf: IfMutex>, + send_buf: IfMutex>, ack_timeout_secs: u16, conn_idle_timeout_secs: u16, _mutex: PhantomData, @@ -81,6 +83,10 @@ where pub fn new_builtin(context: C) -> Self { Self::new(BuiltinGattPeripheral::new(None), context) } + + pub fn init_builtin>(context: C) -> impl Init { + Self::init(BuiltinGattPeripheral::new(None), context) + } } impl Btp @@ -111,13 +117,27 @@ where Self { gatt, context, - send_buf: IfMutex::new(heapless::Vec::new()), + send_buf: IfMutex::new(crate::utils::storage::Vec::new()), ack_timeout_secs, conn_idle_timeout_secs, _mutex: PhantomData, } } + /// Create an in-place initializer for a BTP object with the provided + /// `GattPeripheral` trait in-place initializer and and with the provided BTP + /// `context` in-place initializer. + pub fn init, IC: Init>(gatt: IT, context: IC) -> impl Init { + init!(Self { + gatt <- gatt, + context <- context, + send_buf <- IfMutex::init(crate::utils::storage::Vec::init()), + ack_timeout_secs: BTP_ACK_TIMEOUT_SECS, + conn_idle_timeout_secs: BTP_CONN_IDLE_TIMEOUT_SECS, + _mutex: PhantomData, + }) + } + /// Run the BTP protocol /// /// While all sending and receiving of Matter packets (a.k.a. BTP SDUs) is done via the `recv` and `send` methods @@ -291,7 +311,7 @@ where /// in case it is used by another operation. async fn send_buf( &self, - ) -> impl DerefMut> + '_ { + ) -> impl DerefMut> + '_ { let mut buf = self.send_buf.lock().await; // Unwrap is safe because the max size of the buffer is MAX_PDU_SIZE diff --git a/rs-matter/src/transport/network/btp/context.rs b/rs-matter/src/transport/network/btp/context.rs index 117439e9..8634dff9 100644 --- a/rs-matter/src/transport/network/btp/context.rs +++ b/rs-matter/src/transport/network/btp/context.rs @@ -15,14 +15,16 @@ * limitations under the License. */ -use core::cell::RefCell; +use embassy_sync::blocking_mutex::raw::RawMutex; -use embassy_sync::blocking_mutex::{raw::RawMutex, Mutex}; use log::{error, info, trace, warn}; use crate::error::{Error, ErrorCode}; use crate::transport::network::BtAddr; -use crate::utils::notification::Notification; +use crate::utils::cell::RefCell; +use crate::utils::init::{init, Init, IntoFallibleInit}; +use crate::utils::sync::blocking::Mutex; +use crate::utils::sync::Notification; use super::{session::Session, GattPeripheralEvent}; @@ -182,7 +184,7 @@ pub struct BtpContext where M: RawMutex, { - pub(crate) sessions: Mutex>>, + pub(crate) sessions: Mutex>>, pub(crate) handshake_notif: Notification, pub(crate) available_notif: Notification, pub(crate) recv_notif: Notification, @@ -207,7 +209,7 @@ where #[inline(always)] pub const fn new() -> Self { Self { - sessions: Mutex::new(RefCell::new(heapless::Vec::new())), + sessions: Mutex::new(RefCell::new(crate::utils::storage::Vec::new())), handshake_notif: Notification::new(), available_notif: Notification::new(), recv_notif: Notification::new(), @@ -215,6 +217,18 @@ where send_notif: Notification::new(), } } + + /// Create a BTP context in-place initializer. + pub fn init() -> impl Init { + init!(Self { + sessions <- Mutex::init(RefCell::init(crate::utils::storage::Vec::init())), + handshake_notif: Notification::new(), + available_notif: Notification::new(), + recv_notif: Notification::new(), + ack_notif: Notification::new(), + send_notif: Notification::new(), + }) + } } impl BtpContext @@ -251,9 +265,11 @@ where warn!("Too many BTP sessions, dropping a handshake request from address {address}"); } else { // Unwrap is safe because we checked the length above - sessions - .push(Session::process_rx_handshake(address, data, gatt_mtu)?) - .unwrap(); + sessions.push_init( + Session::process_rx_handshake(address, data, gatt_mtu)?.into_fallible::(), + || ErrorCode::NoSpace.into(), + ) + .unwrap(); } Ok(()) diff --git a/rs-matter/src/transport/network/btp/gatt/bluer.rs b/rs-matter/src/transport/network/btp/gatt/bluer.rs index ffac0740..525a1e7c 100644 --- a/rs-matter/src/transport/network/btp/gatt/bluer.rs +++ b/rs-matter/src/transport/network/btp/gatt/bluer.rs @@ -16,6 +16,7 @@ */ use core::iter::once; +use core::ptr::addr_of_mut; use alloc::sync::Arc; @@ -36,12 +37,14 @@ use log::{info, trace, warn}; use tokio::io::AsyncWriteExt; use tokio_stream::StreamExt; +use crate::error::{Error, ErrorCode}; use crate::transport::network::btp::MIN_MTU; -use crate::{ - error::{Error, ErrorCode}, - transport::network::{btp::context::MAX_BTP_SESSIONS, BtAddr}, - utils::{ifmutex::IfMutex, select::Coalesce, signal::Signal, std_mutex::StdRawMutex}, -}; +use crate::transport::network::{btp::context::MAX_BTP_SESSIONS, BtAddr}; +use crate::utils::init::{init_from_closure, Init}; +use crate::utils::select::Coalesce; +use crate::utils::sync::blocking::raw::StdRawMutex; +use crate::utils::sync::IfMutex; +use crate::utils::sync::Signal; use super::{AdvData, GattPeripheral, GattPeripheralEvent}; use super::{C1_CHARACTERISTIC_UUID, C2_CHARACTERISTIC_UUID, MATTER_BLE_SERVICE_UUID}; @@ -61,6 +64,16 @@ struct GattState { notifiers_listen_allowed: Signal, } +impl GattState { + pub fn new(adapter_name: Option<&str>) -> Self { + Self { + adapter_name: adapter_name.map(|name| name.into()), + notifiers: IfMutex::new(heapless::Vec::new()), + notifiers_listen_allowed: Signal::new(true), + } + } +} + /// Implements the `GattPeripheral` trait using the BlueZ GATT stack. #[derive(Clone)] pub struct BluerGattPeripheral(Arc); @@ -74,11 +87,23 @@ impl Default for BluerGattPeripheral { impl BluerGattPeripheral { /// Create a new instance. pub fn new(adapter_name: Option<&str>) -> Self { - Self(Arc::new(GattState { - adapter_name: adapter_name.map(|name| name.into()), - notifiers: IfMutex::new(heapless::Vec::new()), - notifiers_listen_allowed: Signal::new(true), - })) + Self(Arc::new(GattState::new(adapter_name))) + } + + /// Create an in-place initializer for the peripheral. + pub fn init(adapter_name: Option<&str>) -> impl Init { + let adapter_name = adapter_name.map(|name| name.to_owned()); + + // We can't (yet) use `pinned-init`'s `InPlaceInit` as it relies on unstable Rust features. + // Not so important specifically for this BlueZ implementation because it is STD-only + // and Linux-only, and there's plenty of stack space on Linux. + unsafe { + init_from_closure(move |slot: *mut BluerGattPeripheral| { + addr_of_mut!((*slot).0).write(Arc::new(GattState::new(adapter_name.as_deref()))); + + Ok(()) + }) + } } /// Runs the GATT peripheral service. diff --git a/rs-matter/src/transport/network/btp/session.rs b/rs-matter/src/transport/network/btp/session.rs index e09dd245..21a3a472 100644 --- a/rs-matter/src/transport/network/btp/session.rs +++ b/rs-matter/src/transport/network/btp/session.rs @@ -26,7 +26,8 @@ use crate::error::{Error, ErrorCode}; use crate::transport::network::btp::session::packet::{HandshakeReq, HandshakeResp}; use crate::transport::network::btp::{GATT_HEADER_SIZE, MAX_MTU, MIN_MTU}; use crate::transport::network::{BtAddr, MAX_RX_PACKET_SIZE}; -use crate::utils::{ringbuf::RingBuf, writebuf::WriteBuf}; +use crate::utils::init::{init, Init}; +use crate::utils::storage::{RingBuf, WriteBuf}; use self::packet::BtpHdr; @@ -172,18 +173,17 @@ struct RecvWindow { } impl RecvWindow { - /// Initialize a new receiving window with the provided window size. - #[inline(always)] - pub const fn new(window_size: u8) -> Self { - Self { - buf: RingBuf::new(), + /// Create an in-place initializer for a receiving window with the provided window size. + pub fn init(window_size: u8) -> impl Init { + init!(Self { + buf <- RingBuf::init(), buf_messages_ct: 0, level: window_size, ack_level: 0, ack_seq: 255, received_at: Instant::MAX, rem_msg_len: 0, - } + }) } /// Process an incoming BTP segment, updating the state of the window accordingly. @@ -375,21 +375,21 @@ pub struct Session { } impl Session { - /// Initialize a new BTP session with the provided address, version, MTU and window size. + /// Return an in-place initializer for a new BTP session with the provided address, version, + /// MTU and window size. /// /// Initializing a session is done based on the data that had arrived in the Handshake Request message, /// written by a remote peer on the `C1` characteristic. - #[inline(always)] - const fn new(address: BtAddr, version: u8, mtu: u16, window_size: u8) -> Self { - Self { + fn init(address: BtAddr, version: u8, mtu: u16, window_size: u8) -> impl Init { + init!(Self { address, state: SessionState::New, version, mtu, window_size, - recv_window: RecvWindow::new(window_size), + recv_window <- RecvWindow::init(window_size), send_window: SendWindow::new(window_size), - } + }) } /// Return the address of the remote peer. @@ -495,7 +495,7 @@ impl Session { address: BtAddr, data: &[u8], gatt_mtu: Option, - ) -> Result { + ) -> Result, Error> { let mut iter = data.iter(); let hdr = BtpHdr::from((&mut iter).copied())?; @@ -537,7 +537,7 @@ impl Session { info!("\n>>>>> (BTP IO) {address} [{hdr}]\nHANDSHAKE REQ {req:?}\nSelected version: {version}, MTU: {mtu}, window size: {window_size}"); - Ok(Self::new(address, version, mtu, window_size)) + Ok(Self::init(address, version, mtu, window_size)) } /// Process an incoming BTP segment of a regular data or ACK type, updating the state of the session accordingly. diff --git a/rs-matter/src/transport/network/btp/session/packet.rs b/rs-matter/src/transport/network/btp/session/packet.rs index 1cfef15a..eb681631 100644 --- a/rs-matter/src/transport/network/btp/session/packet.rs +++ b/rs-matter/src/transport/network/btp/session/packet.rs @@ -22,7 +22,7 @@ use bitflags::bitflags; use log::trace; use crate::error::{Error, ErrorCode}; -use crate::utils::writebuf::WriteBuf; +use crate::utils::storage::WriteBuf; bitflags! { /// Models the flags in the BTP header. diff --git a/rs-matter/src/transport/network/btp/test.rs b/rs-matter/src/transport/network/btp/test.rs index 5c56284b..d9ef4a25 100644 --- a/rs-matter/src/transport/network/btp/test.rs +++ b/rs-matter/src/transport/network/btp/test.rs @@ -26,7 +26,8 @@ use alloc::{vec, vec::Vec}; use embassy_futures::block_on; use crate::secure_channel::spake2p::VerifierData; -use crate::utils::{rand::sys_rand, std_mutex::StdRawMutex}; +use crate::utils::rand::sys_rand; +use crate::utils::sync::blocking::raw::StdRawMutex; use super::*; diff --git a/rs-matter/src/transport/packet.rs b/rs-matter/src/transport/packet.rs index 8c4dd882..a1790373 100644 --- a/rs-matter/src/transport/packet.rs +++ b/rs-matter/src/transport/packet.rs @@ -19,11 +19,9 @@ use core::fmt; use log::trace; -use crate::{ - crypto::AEAD_MIC_LEN_BYTES, - error::Error, - utils::{parsebuf::ParseBuf, writebuf::WriteBuf}, -}; +use crate::crypto::AEAD_MIC_LEN_BYTES; +use crate::error::Error; +use crate::utils::storage::{ParseBuf, WriteBuf}; use super::{ plain_hdr::{self, PlainHdr}, diff --git a/rs-matter/src/transport/plain_hdr.rs b/rs-matter/src/transport/plain_hdr.rs index e2b54812..609e88c3 100644 --- a/rs-matter/src/transport/plain_hdr.rs +++ b/rs-matter/src/transport/plain_hdr.rs @@ -18,8 +18,7 @@ use core::fmt; use crate::error::*; -use crate::utils::parsebuf::ParseBuf; -use crate::utils::writebuf::WriteBuf; +use crate::utils::storage::{ParseBuf, WriteBuf}; use bitflags::bitflags; use log::trace; diff --git a/rs-matter/src/transport/proto_hdr.rs b/rs-matter/src/transport/proto_hdr.rs index 65b3bb02..b5641071 100644 --- a/rs-matter/src/transport/proto_hdr.rs +++ b/rs-matter/src/transport/proto_hdr.rs @@ -19,8 +19,7 @@ use bitflags::bitflags; use core::fmt; use crate::transport::plain_hdr; -use crate::utils::parsebuf::ParseBuf; -use crate::utils::writebuf::WriteBuf; +use crate::utils::storage::{ParseBuf, WriteBuf}; use crate::{crypto, error::*}; use log::{trace, warn}; diff --git a/rs-matter/src/transport/session.rs b/rs-matter/src/transport/session.rs index 70110059..e0891268 100644 --- a/rs-matter/src/transport/session.rs +++ b/rs-matter/src/transport/session.rs @@ -15,7 +15,6 @@ * limitations under the License. */ -use core::cell::RefCell; use core::fmt; use core::num::NonZeroU8; use core::time::Duration; @@ -26,10 +25,11 @@ use crate::data_model::sdm::noc::NocData; use crate::error::*; use crate::transport::exchange::ExchangeId; use crate::transport::mrp::ReliableMessage; +use crate::utils::cell::RefCell; use crate::utils::epoch::Epoch; -use crate::utils::parsebuf::ParseBuf; +use crate::utils::init::{init, Init}; use crate::utils::rand::Rand; -use crate::utils::writebuf::WriteBuf; +use crate::utils::storage::{ParseBuf, WriteBuf}; use crate::Matter; use super::dedup::RxCtrState; @@ -538,16 +538,17 @@ pub struct SessionMgr { next_sess_unique_id: u32, next_sess_id: u16, next_exch_id: u16, - sessions: heapless::Vec, + sessions: crate::utils::storage::Vec, pub(crate) epoch: Epoch, pub(crate) rand: Rand, } impl SessionMgr { + /// Create a new session manager. #[inline(always)] pub const fn new(epoch: Epoch, rand: Rand) -> Self { Self { - sessions: heapless::Vec::new(), + sessions: crate::utils::storage::Vec::new(), next_sess_unique_id: 0, next_sess_id: 1, next_exch_id: 1, @@ -556,6 +557,18 @@ impl SessionMgr { } } + /// Create an in-place initializer for a new session manager. + pub fn init(epoch: Epoch, rand: Rand) -> impl Init { + init!(Self { + sessions <- crate::utils::storage::Vec::init(), + next_sess_unique_id: 0, + next_sess_id: 1, + next_exch_id: 1, + epoch, + rand, + }) + } + pub fn reset(&mut self) { self.sessions.clear(); self.next_sess_id = 1; diff --git a/rs-matter/src/utils/cell.rs b/rs-matter/src/utils/cell.rs new file mode 100644 index 00000000..015d8563 --- /dev/null +++ b/rs-matter/src/utils/cell.rs @@ -0,0 +1,1094 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! A modification of the Rust `RefCell` type which provides in-place initialization +//! via `RefCell::init`. +//! +//! NOTE: TEMPORARY and subject to removal once all Matter state is hidden behind +//! a `blmutex::Mutex` in future. + +#![allow(unexpected_cfgs)] +#![allow(clippy::should_implement_trait)] + +use core::cell::{Cell, UnsafeCell}; +use core::cmp::Ordering; +use core::fmt::{self, Debug, Display}; +use core::marker::PhantomData; +use core::mem; +use core::ops::{Deref, DerefMut}; +use core::ptr::NonNull; + +use super::init::{init, Init, UnsafeCellInit}; + +/// A mutable memory location with dynamically checked borrow rules +/// +/// See the [module-level documentation](self) for more. +pub struct RefCell { + borrow: Cell, + // Stores the location of the earliest currently active borrow. + // This gets updated whenever we go from having zero borrows + // to having a single borrow. When a borrow occurs, this gets included + // in the generated `BorrowError`/`BorrowMutError` + #[cfg(feature = "debug_refcell")] + borrowed_at: Cell>>, + _not_sync: PhantomData<*const ()>, + value: UnsafeCell, +} + +/// An error returned by [`RefCell::try_borrow`]. +#[non_exhaustive] +pub struct BorrowError { + #[cfg(feature = "debug_refcell")] + location: &'static crate::panic::Location<'static>, +} + +impl Debug for BorrowError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = f.debug_struct("BorrowError"); + + #[cfg(feature = "debug_refcell")] + builder.field("location", self.location); + + builder.finish() + } +} + +impl Display for BorrowError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + Display::fmt("already mutably borrowed", f) + } +} + +/// An error returned by [`RefCell::try_borrow_mut`]. +#[non_exhaustive] +pub struct BorrowMutError { + #[cfg(feature = "debug_refcell")] + location: &'static crate::panic::Location<'static>, +} + +impl Debug for BorrowMutError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = f.debug_struct("BorrowMutError"); + + #[cfg(feature = "debug_refcell")] + builder.field("location", self.location); + + builder.finish() + } +} + +impl Display for BorrowMutError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + Display::fmt("already borrowed", f) + } +} + +// This ensures the panicking code is outlined from `borrow_mut` for `RefCell`. +#[cfg_attr(not(feature = "panic_immediate_abort"), inline(never))] +#[track_caller] +#[cold] +fn panic_already_borrowed(err: BorrowMutError) -> ! { + panic!("already borrowed: {:?}", err) +} + +// This ensures the panicking code is outlined from `borrow` for `RefCell`. +#[cfg_attr(not(feature = "panic_immediate_abort"), inline(never))] +#[track_caller] +#[cold] +fn panic_already_mutably_borrowed(err: BorrowError) -> ! { + panic!("already mutably borrowed: {:?}", err) +} + +// Positive values represent the number of `Ref` active. Negative values +// represent the number of `RefMut` active. Multiple `RefMut`s can only be +// active at a time if they refer to distinct, nonoverlapping components of a +// `RefCell` (e.g., different ranges of a slice). +// +// `Ref` and `RefMut` are both two words in size, and so there will likely never +// be enough `Ref`s or `RefMut`s in existence to overflow half of the `usize` +// range. Thus, a `BorrowFlag` will probably never overflow or underflow. +// However, this is not a guarantee, as a pathological program could repeatedly +// create and then mem::forget `Ref`s or `RefMut`s. Thus, all code must +// explicitly check for overflow and underflow in order to avoid unsafety, or at +// least behave correctly in the event that overflow or underflow happens (e.g., +// see BorrowRef::new). +type BorrowFlag = isize; +const UNUSED: BorrowFlag = 0; + +#[inline(always)] +fn is_writing(x: BorrowFlag) -> bool { + x < UNUSED +} + +#[inline(always)] +fn is_reading(x: BorrowFlag) -> bool { + x > UNUSED +} + +impl RefCell { + /// Creates a new `RefCell` containing `value`. + /// + /// # Examples + /// + /// ``` + /// use std::cell::RefCell; + /// + /// let c = RefCell::new(5); + /// ``` + #[inline] + pub const fn new(value: T) -> RefCell { + RefCell { + value: UnsafeCell::new(value), + borrow: Cell::new(UNUSED), + #[cfg(feature = "debug_refcell")] + borrowed_at: Cell::new(None), + _not_sync: PhantomData, + } + } + + /// Creates a new `RefCell` in-place initializer + /// by using the given value initializer. + pub fn init>(value: I) -> impl Init { + init!(Self { + value <- UnsafeCell::init(value), + borrow: Cell::new(UNUSED), + // #[cfg(feature = "debug_refcell")] + // borrowed_at: Cell::new(None), + _not_sync: PhantomData, + }) + } + + /// Consumes the `RefCell`, returning the wrapped value. + /// + /// # Examples + /// + /// ``` + /// use std::cell::RefCell; + /// + /// let c = RefCell::new(5); + /// + /// let five = c.into_inner(); + /// ``` + #[inline] + pub fn into_inner(self) -> T { + // Since this function takes `self` (the `RefCell`) by value, the + // compiler statically verifies that it is not currently borrowed. + self.value.into_inner() + } + + /// Replaces the wrapped value with a new one, returning the old value, + /// without deinitializing either one. + /// + /// This function corresponds to [`std::mem::replace`](../mem/fn.replace.html). + /// + /// # Panics + /// + /// Panics if the value is currently borrowed. + /// + /// # Examples + /// + /// ``` + /// use std::cell::RefCell; + /// let cell = RefCell::new(5); + /// let old_value = cell.replace(6); + /// assert_eq!(old_value, 5); + /// assert_eq!(cell, RefCell::new(6)); + /// ``` + #[inline] + #[track_caller] + pub fn replace(&self, t: T) -> T { + mem::replace(&mut *self.borrow_mut(), t) + } + + /// Replaces the wrapped value with a new one computed from `f`, returning + /// the old value, without deinitializing either one. + /// + /// # Panics + /// + /// Panics if the value is currently borrowed. + /// + /// # Examples + /// + /// ``` + /// use std::cell::RefCell; + /// let cell = RefCell::new(5); + /// let old_value = cell.replace_with(|&mut old| old + 1); + /// assert_eq!(old_value, 5); + /// assert_eq!(cell, RefCell::new(6)); + /// ``` + #[inline] + #[track_caller] + pub fn replace_with T>(&self, f: F) -> T { + let mut_borrow = &mut *self.borrow_mut(); + let replacement = f(mut_borrow); + mem::replace(mut_borrow, replacement) + } + + /// Swaps the wrapped value of `self` with the wrapped value of `other`, + /// without deinitializing either one. + /// + /// This function corresponds to [`std::mem::swap`](../mem/fn.swap.html). + /// + /// # Panics + /// + /// Panics if the value in either `RefCell` is currently borrowed, or + /// if `self` and `other` point to the same `RefCell`. + /// + /// # Examples + /// + /// ``` + /// use std::cell::RefCell; + /// let c = RefCell::new(5); + /// let d = RefCell::new(6); + /// c.swap(&d); + /// assert_eq!(c, RefCell::new(6)); + /// assert_eq!(d, RefCell::new(5)); + /// ``` + #[inline] + pub fn swap(&self, other: &Self) { + mem::swap(&mut *self.borrow_mut(), &mut *other.borrow_mut()) + } +} + +impl RefCell { + /// Immutably borrows the wrapped value. + /// + /// The borrow lasts until the returned `Ref` exits scope. Multiple + /// immutable borrows can be taken out at the same time. + /// + /// # Panics + /// + /// Panics if the value is currently mutably borrowed. For a non-panicking variant, use + /// [`try_borrow`](#method.try_borrow). + /// + /// # Examples + /// + /// ``` + /// use std::cell::RefCell; + /// + /// let c = RefCell::new(5); + /// + /// let borrowed_five = c.borrow(); + /// let borrowed_five2 = c.borrow(); + /// ``` + /// + /// An example of panic: + /// + /// ```should_panic + /// use std::cell::RefCell; + /// + /// let c = RefCell::new(5); + /// + /// let m = c.borrow_mut(); + /// let b = c.borrow(); // this causes a panic + /// ``` + #[inline] + #[track_caller] + pub fn borrow(&self) -> Ref<'_, T> { + match self.try_borrow() { + Ok(b) => b, + Err(err) => panic_already_mutably_borrowed(err), + } + } + + /// Immutably borrows the wrapped value, returning an error if the value is currently mutably + /// borrowed. + /// + /// The borrow lasts until the returned `Ref` exits scope. Multiple immutable borrows can be + /// taken out at the same time. + /// + /// This is the non-panicking variant of [`borrow`](#method.borrow). + /// + /// # Examples + /// + /// ``` + /// use std::cell::RefCell; + /// + /// let c = RefCell::new(5); + /// + /// { + /// let m = c.borrow_mut(); + /// assert!(c.try_borrow().is_err()); + /// } + /// + /// { + /// let m = c.borrow(); + /// assert!(c.try_borrow().is_ok()); + /// } + /// ``` + #[inline] + #[cfg_attr(feature = "debug_refcell", track_caller)] + pub fn try_borrow(&self) -> Result, BorrowError> { + match BorrowRef::new(&self.borrow) { + Some(b) => { + #[cfg(feature = "debug_refcell")] + { + // `borrowed_at` is always the *first* active borrow + if b.borrow.get() == 1 { + self.borrowed_at.set(Some(crate::panic::Location::caller())); + } + } + + // SAFETY: `BorrowRef` ensures that there is only immutable access + // to the value while borrowed. + let value = unsafe { NonNull::new_unchecked(self.value.get()) }; + Ok(Ref { value, borrow: b }) + } + None => Err(BorrowError { + // If a borrow occurred, then we must already have an outstanding borrow, + // so `borrowed_at` will be `Some` + #[cfg(feature = "debug_refcell")] + location: self.borrowed_at.get().unwrap(), + }), + } + } + + /// Mutably borrows the wrapped value. + /// + /// The borrow lasts until the returned `RefMut` or all `RefMut`s derived + /// from it exit scope. The value cannot be borrowed while this borrow is + /// active. + /// + /// # Panics + /// + /// Panics if the value is currently borrowed. For a non-panicking variant, use + /// [`try_borrow_mut`](#method.try_borrow_mut). + /// + /// # Examples + /// + /// ``` + /// use std::cell::RefCell; + /// + /// let c = RefCell::new("hello".to_owned()); + /// + /// *c.borrow_mut() = "bonjour".to_owned(); + /// + /// assert_eq!(&*c.borrow(), "bonjour"); + /// ``` + /// + /// An example of panic: + /// + /// ```should_panic + /// use std::cell::RefCell; + /// + /// let c = RefCell::new(5); + /// let m = c.borrow(); + /// + /// let b = c.borrow_mut(); // this causes a panic + /// ``` + #[inline] + #[track_caller] + pub fn borrow_mut(&self) -> RefMut<'_, T> { + match self.try_borrow_mut() { + Ok(b) => b, + Err(err) => panic_already_borrowed(err), + } + } + + /// Mutably borrows the wrapped value, returning an error if the value is currently borrowed. + /// + /// The borrow lasts until the returned `RefMut` or all `RefMut`s derived + /// from it exit scope. The value cannot be borrowed while this borrow is + /// active. + /// + /// This is the non-panicking variant of [`borrow_mut`](#method.borrow_mut). + /// + /// # Examples + /// + /// ``` + /// use std::cell::RefCell; + /// + /// let c = RefCell::new(5); + /// + /// { + /// let m = c.borrow(); + /// assert!(c.try_borrow_mut().is_err()); + /// } + /// + /// assert!(c.try_borrow_mut().is_ok()); + /// ``` + #[inline] + #[cfg_attr(feature = "debug_refcell", track_caller)] + pub fn try_borrow_mut(&self) -> Result, BorrowMutError> { + match BorrowRefMut::new(&self.borrow) { + Some(b) => { + #[cfg(feature = "debug_refcell")] + { + self.borrowed_at.set(Some(crate::panic::Location::caller())); + } + + // SAFETY: `BorrowRefMut` guarantees unique access. + let value = unsafe { NonNull::new_unchecked(self.value.get()) }; + Ok(RefMut { + value, + borrow: b, + marker: PhantomData, + }) + } + None => Err(BorrowMutError { + // If a borrow occurred, then we must already have an outstanding borrow, + // so `borrowed_at` will be `Some` + #[cfg(feature = "debug_refcell")] + location: self.borrowed_at.get().unwrap(), + }), + } + } + + /// Returns a raw pointer to the underlying data in this cell. + /// + /// # Examples + /// + /// ``` + /// use std::cell::RefCell; + /// + /// let c = RefCell::new(5); + /// + /// let ptr = c.as_ptr(); + /// ``` + #[inline] + //#[rustc_never_returns_null_ptr] + pub fn as_ptr(&self) -> *mut T { + self.value.get() + } + + /// Returns a mutable reference to the underlying data. + /// + /// Since this method borrows `RefCell` mutably, it is statically guaranteed + /// that no borrows to the underlying data exist. The dynamic checks inherent + /// in [`borrow_mut`] and most other methods of `RefCell` are therefore + /// unnecessary. + /// + /// This method can only be called if `RefCell` can be mutably borrowed, + /// which in general is only the case directly after the `RefCell` has + /// been created. In these situations, skipping the aforementioned dynamic + /// borrowing checks may yield better ergonomics and runtime-performance. + /// + /// In most situations where `RefCell` is used, it can't be borrowed mutably. + /// Use [`borrow_mut`] to get mutable access to the underlying data then. + /// + /// [`borrow_mut`]: RefCell::borrow_mut() + /// + /// # Examples + /// + /// ``` + /// use std::cell::RefCell; + /// + /// let mut c = RefCell::new(5); + /// *c.get_mut() += 1; + /// + /// assert_eq!(c, RefCell::new(6)); + /// ``` + #[inline] + pub fn get_mut(&mut self) -> &mut T { + self.value.get_mut() + } + + /// Immutably borrows the wrapped value, returning an error if the value is + /// currently mutably borrowed. + /// + /// # Safety + /// + /// Unlike `RefCell::borrow`, this method is unsafe because it does not + /// return a `Ref`, thus leaving the borrow flag untouched. Mutably + /// borrowing the `RefCell` while the reference returned by this method + /// is alive is undefined behaviour. + /// + /// # Examples + /// + /// ``` + /// use std::cell::RefCell; + /// + /// let c = RefCell::new(5); + /// + /// { + /// let m = c.borrow_mut(); + /// assert!(unsafe { c.try_borrow_unguarded() }.is_err()); + /// } + /// + /// { + /// let m = c.borrow(); + /// assert!(unsafe { c.try_borrow_unguarded() }.is_ok()); + /// } + /// ``` + #[inline] + pub unsafe fn try_borrow_unguarded(&self) -> Result<&T, BorrowError> { + if !is_writing(self.borrow.get()) { + // SAFETY: We check that nobody is actively writing now, but it is + // the caller's responsibility to ensure that nobody writes until + // the returned reference is no longer in use. + // Also, `self.value.get()` refers to the value owned by `self` + // and is thus guaranteed to be valid for the lifetime of `self`. + Ok(unsafe { &*self.value.get() }) + } else { + Err(BorrowError { + // If a borrow occurred, then we must already have an outstanding borrow, + // so `borrowed_at` will be `Some` + #[cfg(feature = "debug_refcell")] + location: self.borrowed_at.get().unwrap(), + }) + } + } +} + +impl RefCell { + /// Takes the wrapped value, leaving `Default::default()` in its place. + /// + /// # Panics + /// + /// Panics if the value is currently borrowed. + /// + /// # Examples + /// + /// ``` + /// use std::cell::RefCell; + /// + /// let c = RefCell::new(5); + /// let five = c.take(); + /// + /// assert_eq!(five, 5); + /// assert_eq!(c.into_inner(), 0); + /// ``` + pub fn take(&self) -> T { + self.replace(Default::default()) + } +} + +unsafe impl Send for RefCell where T: Send {} + +impl Clone for RefCell { + /// # Panics + /// + /// Panics if the value is currently mutably borrowed. + #[inline] + #[track_caller] + fn clone(&self) -> RefCell { + RefCell::new(self.borrow().clone()) + } + + /// # Panics + /// + /// Panics if `other` is currently mutably borrowed. + #[inline] + #[track_caller] + fn clone_from(&mut self, other: &Self) { + self.get_mut().clone_from(&other.borrow()) + } +} + +impl Default for RefCell { + /// Creates a `RefCell`, with the `Default` value for T. + #[inline] + fn default() -> RefCell { + RefCell::new(Default::default()) + } +} + +impl PartialEq for RefCell { + /// # Panics + /// + /// Panics if the value in either `RefCell` is currently mutably borrowed. + #[inline] + fn eq(&self, other: &RefCell) -> bool { + *self.borrow() == *other.borrow() + } +} + +impl Eq for RefCell {} + +impl PartialOrd for RefCell { + /// # Panics + /// + /// Panics if the value in either `RefCell` is currently mutably borrowed. + #[inline] + fn partial_cmp(&self, other: &RefCell) -> Option { + self.borrow().partial_cmp(&*other.borrow()) + } + + /// # Panics + /// + /// Panics if the value in either `RefCell` is currently mutably borrowed. + #[inline] + fn lt(&self, other: &RefCell) -> bool { + *self.borrow() < *other.borrow() + } + + /// # Panics + /// + /// Panics if the value in either `RefCell` is currently mutably borrowed. + #[inline] + fn le(&self, other: &RefCell) -> bool { + *self.borrow() <= *other.borrow() + } + + /// # Panics + /// + /// Panics if the value in either `RefCell` is currently mutably borrowed. + #[inline] + fn gt(&self, other: &RefCell) -> bool { + *self.borrow() > *other.borrow() + } + + /// # Panics + /// + /// Panics if the value in either `RefCell` is currently mutably borrowed. + #[inline] + fn ge(&self, other: &RefCell) -> bool { + *self.borrow() >= *other.borrow() + } +} + +impl Ord for RefCell { + /// # Panics + /// + /// Panics if the value in either `RefCell` is currently mutably borrowed. + #[inline] + fn cmp(&self, other: &RefCell) -> Ordering { + self.borrow().cmp(&*other.borrow()) + } +} + +impl From for RefCell { + /// Creates a new `RefCell` containing the given value. + fn from(t: T) -> RefCell { + RefCell::new(t) + } +} + +struct BorrowRef<'b> { + borrow: &'b Cell, +} + +impl<'b> BorrowRef<'b> { + #[inline] + fn new(borrow: &'b Cell) -> Option> { + let b = borrow.get().wrapping_add(1); + if !is_reading(b) { + // Incrementing borrow can result in a non-reading value (<= 0) in these cases: + // 1. It was < 0, i.e. there are writing borrows, so we can't allow a read borrow + // due to Rust's reference aliasing rules + // 2. It was isize::MAX (the max amount of reading borrows) and it overflowed + // into isize::MIN (the max amount of writing borrows) so we can't allow + // an additional read borrow because isize can't represent so many read borrows + // (this can only happen if you mem::forget more than a small constant amount of + // `Ref`s, which is not good practice) + None + } else { + // Incrementing borrow can result in a reading value (> 0) in these cases: + // 1. It was = 0, i.e. it wasn't borrowed, and we are taking the first read borrow + // 2. It was > 0 and < isize::MAX, i.e. there were read borrows, and isize + // is large enough to represent having one more read borrow + borrow.set(b); + Some(BorrowRef { borrow }) + } + } +} + +impl Drop for BorrowRef<'_> { + #[inline] + fn drop(&mut self) { + let borrow = self.borrow.get(); + debug_assert!(is_reading(borrow)); + self.borrow.set(borrow - 1); + } +} + +impl Clone for BorrowRef<'_> { + #[inline] + fn clone(&self) -> Self { + // Since this Ref exists, we know the borrow flag + // is a reading borrow. + let borrow = self.borrow.get(); + debug_assert!(is_reading(borrow)); + // Prevent the borrow counter from overflowing into + // a writing borrow. + assert!(borrow != BorrowFlag::MAX); + self.borrow.set(borrow + 1); + BorrowRef { + borrow: self.borrow, + } + } +} + +/// Wraps a borrowed reference to a value in a `RefCell` box. +/// A wrapper type for an immutably borrowed value from a `RefCell`. +/// +/// See the [module-level documentation](self) for more. +// #[must_not_suspend = "holding a Ref across suspend points can cause BorrowErrors"] +// #[rustc_diagnostic_item = "RefCellRef"] +pub struct Ref<'b, T: ?Sized + 'b> { + // NB: we use a pointer instead of `&'b T` to avoid `noalias` violations, because a + // `Ref` argument doesn't hold immutability for its whole scope, only until it drops. + // `NonNull` is also covariant over `T`, just like we would have with `&T`. + value: NonNull, + borrow: BorrowRef<'b>, +} + +impl Deref for Ref<'_, T> { + type Target = T; + + #[inline] + fn deref(&self) -> &T { + // SAFETY: the value is accessible as long as we hold our borrow. + unsafe { self.value.as_ref() } + } +} + +impl<'b, T: ?Sized> Ref<'b, T> { + /// Copies a `Ref`. + /// + /// The `RefCell` is already immutably borrowed, so this cannot fail. + /// + /// This is an associated function that needs to be used as + /// `Ref::clone(...)`. A `Clone` implementation or a method would interfere + /// with the widespread use of `r.borrow().clone()` to clone the contents of + /// a `RefCell`. + #[must_use] + #[inline] + pub fn clone(orig: &Ref<'b, T>) -> Ref<'b, T> { + Ref { + value: orig.value, + borrow: orig.borrow.clone(), + } + } + + /// Makes a new `Ref` for a component of the borrowed data. + /// + /// The `RefCell` is already immutably borrowed, so this cannot fail. + /// + /// This is an associated function that needs to be used as `Ref::map(...)`. + /// A method would interfere with methods of the same name on the contents + /// of a `RefCell` used through `Deref`. + /// + /// # Examples + /// + /// ``` + /// use std::cell::{RefCell, Ref}; + /// + /// let c = RefCell::new((5, 'b')); + /// let b1: Ref<'_, (u32, char)> = c.borrow(); + /// let b2: Ref<'_, u32> = Ref::map(b1, |t| &t.0); + /// assert_eq!(*b2, 5) + /// ``` + #[inline] + pub fn map(orig: Ref<'b, T>, f: F) -> Ref<'b, U> + where + F: FnOnce(&T) -> &U, + { + Ref { + value: NonNull::from(f(&*orig)), + borrow: orig.borrow, + } + } + + /// Makes a new `Ref` for an optional component of the borrowed data. The + /// original guard is returned as an `Err(..)` if the closure returns + /// `None`. + /// + /// The `RefCell` is already immutably borrowed, so this cannot fail. + /// + /// This is an associated function that needs to be used as + /// `Ref::filter_map(...)`. A method would interfere with methods of the same + /// name on the contents of a `RefCell` used through `Deref`. + /// + /// # Examples + /// + /// ``` + /// use std::cell::{RefCell, Ref}; + /// + /// let c = RefCell::new(vec![1, 2, 3]); + /// let b1: Ref<'_, Vec> = c.borrow(); + /// let b2: Result, _> = Ref::filter_map(b1, |v| v.get(1)); + /// assert_eq!(*b2.unwrap(), 2); + /// ``` + #[inline] + pub fn filter_map(orig: Ref<'b, T>, f: F) -> Result, Self> + where + F: FnOnce(&T) -> Option<&U>, + { + match f(&*orig) { + Some(value) => Ok(Ref { + value: NonNull::from(value), + borrow: orig.borrow, + }), + None => Err(orig), + } + } + + /// Splits a `Ref` into multiple `Ref`s for different components of the + /// borrowed data. + /// + /// The `RefCell` is already immutably borrowed, so this cannot fail. + /// + /// This is an associated function that needs to be used as + /// `Ref::map_split(...)`. A method would interfere with methods of the same + /// name on the contents of a `RefCell` used through `Deref`. + /// + /// # Examples + /// + /// ``` + /// use std::cell::{Ref, RefCell}; + /// + /// let cell = RefCell::new([1, 2, 3, 4]); + /// let borrow = cell.borrow(); + /// let (begin, end) = Ref::map_split(borrow, |slice| slice.split_at(2)); + /// assert_eq!(*begin, [1, 2]); + /// assert_eq!(*end, [3, 4]); + /// ``` + #[inline] + pub fn map_split(orig: Ref<'b, T>, f: F) -> (Ref<'b, U>, Ref<'b, V>) + where + F: FnOnce(&T) -> (&U, &V), + { + let (a, b) = f(&*orig); + let borrow = orig.borrow.clone(); + ( + Ref { + value: NonNull::from(a), + borrow, + }, + Ref { + value: NonNull::from(b), + borrow: orig.borrow, + }, + ) + } +} + +impl fmt::Display for Ref<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} + +impl<'b, T: ?Sized> RefMut<'b, T> { + /// Makes a new `RefMut` for a component of the borrowed data, e.g., an enum + /// variant. + /// + /// The `RefCell` is already mutably borrowed, so this cannot fail. + /// + /// This is an associated function that needs to be used as + /// `RefMut::map(...)`. A method would interfere with methods of the same + /// name on the contents of a `RefCell` used through `Deref`. + /// + /// # Examples + /// + /// ``` + /// use std::cell::{RefCell, RefMut}; + /// + /// let c = RefCell::new((5, 'b')); + /// { + /// let b1: RefMut<'_, (u32, char)> = c.borrow_mut(); + /// let mut b2: RefMut<'_, u32> = RefMut::map(b1, |t| &mut t.0); + /// assert_eq!(*b2, 5); + /// *b2 = 42; + /// } + /// assert_eq!(*c.borrow(), (42, 'b')); + /// ``` + #[inline] + pub fn map(mut orig: RefMut<'b, T>, f: F) -> RefMut<'b, U> + where + F: FnOnce(&mut T) -> &mut U, + { + let value = NonNull::from(f(&mut *orig)); + RefMut { + value, + borrow: orig.borrow, + marker: PhantomData, + } + } + + /// Makes a new `RefMut` for an optional component of the borrowed data. The + /// original guard is returned as an `Err(..)` if the closure returns + /// `None`. + /// + /// The `RefCell` is already mutably borrowed, so this cannot fail. + /// + /// This is an associated function that needs to be used as + /// `RefMut::filter_map(...)`. A method would interfere with methods of the + /// same name on the contents of a `RefCell` used through `Deref`. + /// + /// # Examples + /// + /// ``` + /// use std::cell::{RefCell, RefMut}; + /// + /// let c = RefCell::new(vec![1, 2, 3]); + /// + /// { + /// let b1: RefMut<'_, Vec> = c.borrow_mut(); + /// let mut b2: Result, _> = RefMut::filter_map(b1, |v| v.get_mut(1)); + /// + /// if let Ok(mut b2) = b2 { + /// *b2 += 2; + /// } + /// } + /// + /// assert_eq!(*c.borrow(), vec![1, 4, 3]); + /// ``` + #[inline] + pub fn filter_map(mut orig: RefMut<'b, T>, f: F) -> Result, Self> + where + F: FnOnce(&mut T) -> Option<&mut U>, + { + // SAFETY: function holds onto an exclusive reference for the duration + // of its call through `orig`, and the pointer is only de-referenced + // inside of the function call never allowing the exclusive reference to + // escape. + match f(&mut *orig) { + Some(value) => Ok(RefMut { + value: NonNull::from(value), + borrow: orig.borrow, + marker: PhantomData, + }), + None => Err(orig), + } + } + + /// Splits a `RefMut` into multiple `RefMut`s for different components of the + /// borrowed data. + /// + /// The underlying `RefCell` will remain mutably borrowed until both + /// returned `RefMut`s go out of scope. + /// + /// The `RefCell` is already mutably borrowed, so this cannot fail. + /// + /// This is an associated function that needs to be used as + /// `RefMut::map_split(...)`. A method would interfere with methods of the + /// same name on the contents of a `RefCell` used through `Deref`. + /// + /// # Examples + /// + /// ``` + /// use std::cell::{RefCell, RefMut}; + /// + /// let cell = RefCell::new([1, 2, 3, 4]); + /// let borrow = cell.borrow_mut(); + /// let (mut begin, mut end) = RefMut::map_split(borrow, |slice| slice.split_at_mut(2)); + /// assert_eq!(*begin, [1, 2]); + /// assert_eq!(*end, [3, 4]); + /// begin.copy_from_slice(&[4, 3]); + /// end.copy_from_slice(&[2, 1]); + /// ``` + #[inline] + pub fn map_split( + mut orig: RefMut<'b, T>, + f: F, + ) -> (RefMut<'b, U>, RefMut<'b, V>) + where + F: FnOnce(&mut T) -> (&mut U, &mut V), + { + let borrow = orig.borrow.clone(); + let (a, b) = f(&mut *orig); + ( + RefMut { + value: NonNull::from(a), + borrow, + marker: PhantomData, + }, + RefMut { + value: NonNull::from(b), + borrow: orig.borrow, + marker: PhantomData, + }, + ) + } +} + +struct BorrowRefMut<'b> { + borrow: &'b Cell, +} + +impl Drop for BorrowRefMut<'_> { + #[inline] + fn drop(&mut self) { + let borrow = self.borrow.get(); + debug_assert!(is_writing(borrow)); + self.borrow.set(borrow + 1); + } +} + +impl<'b> BorrowRefMut<'b> { + #[inline] + fn new(borrow: &'b Cell) -> Option> { + // NOTE: Unlike BorrowRefMut::clone, new is called to create the initial + // mutable reference, and so there must currently be no existing + // references. Thus, while clone increments the mutable refcount, here + // we explicitly only allow going from UNUSED to UNUSED - 1. + match borrow.get() { + UNUSED => { + borrow.set(UNUSED - 1); + Some(BorrowRefMut { borrow }) + } + _ => None, + } + } + + // Clones a `BorrowRefMut`. + // + // This is only valid if each `BorrowRefMut` is used to track a mutable + // reference to a distinct, nonoverlapping range of the original object. + // This isn't in a Clone impl so that code doesn't call this implicitly. + #[inline] + fn clone(&self) -> BorrowRefMut<'b> { + let borrow = self.borrow.get(); + debug_assert!(is_writing(borrow)); + // Prevent the borrow counter from underflowing. + assert!(borrow != BorrowFlag::MIN); + self.borrow.set(borrow - 1); + BorrowRefMut { + borrow: self.borrow, + } + } +} + +/// A wrapper type for a mutably borrowed value from a `RefCell`. +/// +/// See the [module-level documentation](self) for more. +// #[must_not_suspend = "holding a RefMut across suspend points can cause BorrowErrors"] +// #[rustc_diagnostic_item = "RefCellRefMut"] +pub struct RefMut<'b, T: ?Sized + 'b> { + // NB: we use a pointer instead of `&'b mut T` to avoid `noalias` violations, because a + // `RefMut` argument doesn't hold exclusivity for its whole scope, only until it drops. + value: NonNull, + borrow: BorrowRefMut<'b>, + // `NonNull` is covariant over `T`, so we need to reintroduce invariance. + marker: PhantomData<&'b mut T>, +} + +impl Deref for RefMut<'_, T> { + type Target = T; + + #[inline] + fn deref(&self) -> &T { + // SAFETY: the value is accessible as long as we hold our borrow. + unsafe { self.value.as_ref() } + } +} + +impl DerefMut for RefMut<'_, T> { + #[inline] + fn deref_mut(&mut self) -> &mut T { + // SAFETY: the value is accessible as long as we hold our borrow. + unsafe { self.value.as_mut() } + } +} + +impl fmt::Display for RefMut<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} diff --git a/rs-matter/src/utils/epoch.rs b/rs-matter/src/utils/epoch.rs index 630f8783..17ad4a2b 100644 --- a/rs-matter/src/utils/epoch.rs +++ b/rs-matter/src/utils/epoch.rs @@ -1,3 +1,20 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + use core::time::Duration; pub type Epoch = fn() -> Duration; diff --git a/rs-matter/src/utils/init.rs b/rs-matter/src/utils/init.rs new file mode 100644 index 00000000..e3bedfc1 --- /dev/null +++ b/rs-matter/src/utils/init.rs @@ -0,0 +1,91 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use core::convert::Infallible; +use core::{cell::UnsafeCell, mem::MaybeUninit}; + +/// Re-export `pinned-init` because its API is very unstable currently (0.0.x) +pub use pinned_init::*; + +/// An extension trait for converting `Init` to a fallible `Init`. +/// Useful when chaining an infallible initializer with a fallible chained initialization function. +pub trait IntoFallibleInit: Init { + /// Convert the infallible initializer to a fallible one. + fn into_fallible(self) -> impl Init { + unsafe { + init_from_closure(move |slot| { + Self::__init(self, slot).unwrap(); + + Ok(()) + }) + } + } +} + +impl IntoFallibleInit for I where I: Init {} + +/// An extension trait for re-setting an already instantiated `T` with the given initializer. +pub trait ApplyInit: Init { + fn apply(self, to: &mut T) -> Result<(), E> { + unsafe { Self::__init(self, to as *mut T) } + } +} + +impl ApplyInit for I where I: Init {} +/// An extension trait for retrofitting `UnsafeCell` with an initializer. +pub trait UnsafeCellInit { + /// Create a new in-place initializer for `UnsafeCell` + /// by using the given initializer for the value. + fn init>(value: I) -> impl Init; +} + +impl UnsafeCellInit for UnsafeCell { + fn init>(value: I) -> impl Init { + unsafe { + init_from_closure::<_, Infallible>(move |slot: *mut Self| { + // `slot` contains uninit memory, avoid creating a reference. + let slot: *mut T = slot as _; + + // Initialize the value + value.__init(slot).unwrap(); + + Ok(()) + }) + } + } +} + +/// An extension trait that allows safe initialization of +/// `MaybeUninit` memory. +pub trait InitMaybeUninit { + /// Initialize Self with the given in-place initializer. + fn init_with>(&mut self, init: I) -> &mut T { + self.try_init_with(init).unwrap() + } + + fn try_init_with, E>(&mut self, init: I) -> Result<&mut T, E>; +} + +impl InitMaybeUninit for MaybeUninit { + fn try_init_with, E>(&mut self, init: I) -> Result<&mut T, E> { + unsafe { + Init::::__init(init, self.as_mut_ptr())?; + + Ok(self.assume_init_mut()) + } + } +} diff --git a/rs-matter/src/utils/maybe.rs b/rs-matter/src/utils/maybe.rs new file mode 100644 index 00000000..b79cf206 --- /dev/null +++ b/rs-matter/src/utils/maybe.rs @@ -0,0 +1,215 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use core::fmt::Debug; +use core::hash::Hash; +use core::marker::PhantomData; +use core::mem::MaybeUninit; +use core::ops::{Deref, DerefMut}; +use core::ptr::addr_of_mut; + +use super::init; + +/// Represents a type similar in spirit to the built-in `Option` type. +/// Unlike `Option` however, `Maybe` _does_ have in-place initializer support. +/// +/// (In-place initializer support is impossible to provide for `Option` due to its +/// enum nature, and because it is not marked with `repr(transparent)`). +/// +/// `Maybe` is convertable to and from `Option` (via the `From` / `Into` traits), +/// however these conversions are not recommended when the wrapped value is large +/// which defeats the purpose of using `Maybe` in the first place. +/// +/// The canonical way to use `Maybe` with large values is to initialize it in-place with +/// one of the provided init constructors, and then use one of the `as_ref`, `as_mut`, +/// `as_deref` and `as_deref_mut` methods to access the wrapped value. +#[derive(Debug)] +pub struct Maybe { + some: bool, + value: MaybeUninit, + _tag: PhantomData, +} + +impl Maybe { + /// Create a new `Maybe` value from an `Option`. + /// + /// Note that when the wrapped value is large, it is recommended instead to use + /// `Maybe::init_none()` and `Maybe::init_some()` to create the `Maybe` value in-place. + pub fn new(value: Option) -> Self { + match value { + Some(v) => Self::some(v), + None => Self::none(), + } + } + + /// Create a new, empty `Maybe` value. + pub const fn none() -> Self { + Self { + some: false, + value: MaybeUninit::uninit(), + _tag: PhantomData, + } + } + + /// Create a new `Maybe` value with a wrapped value. + pub const fn some(value: T) -> Self { + Self { + some: true, + value: MaybeUninit::new(value), + _tag: PhantomData, + } + } + + /// Create an in-place initializer for a `Maybe` value that is empty. + pub fn init_none, E>() -> impl init::Init { + Self::init::(None) + } + + /// Create an in-place initializer for a `Maybe` value that is not empty + /// by initializing the wrapped value with the provided initializer. + pub fn init_some, E>(value: I) -> impl init::Init { + Self::init(Some(value)) + } + + /// Create an in-place initializer for a `Maybe` value that might or might + /// not be empty. + pub fn init, E>(value: Option) -> impl init::Init { + unsafe { + init::init_from_closure(move |slot: *mut Self| { + addr_of_mut!((*slot).some).write(value.is_some()); + + if let Some(value) = value { + value.__init(addr_of_mut!((*slot).value) as _)?; + } + + Ok(()) + }) + } + } + + /// Return a mutable reference to the wrapped value, if it exists. + pub fn as_mut(&mut self) -> Option<&mut T> { + if self.some { + Some(unsafe { self.value.assume_init_mut() }) + } else { + None + } + } + + /// Return a reference to the wrapped value, if it exists. + pub fn as_ref(&self) -> Option<&T> { + if self.some { + Some(unsafe { self.value.assume_init_ref() }) + } else { + None + } + } + + /// Derefs the wrapped value, if it exists. + pub fn as_deref(&self) -> Option<&T::Target> + where + T: Deref, + { + match self.as_ref() { + Some(t) => Some(t.deref()), + None => None, + } + } + + /// Derefs mutably the wrapped value, if it exists. + pub fn as_deref_mut(&mut self) -> Option<&mut T::Target> + where + T: DerefMut, + { + match self.as_mut() { + Some(t) => Some(t.deref_mut()), + None => None, + } + } + + /// Consume the `Maybe` value and return the wrapped value, if it exists. + /// + /// Note that this method is not efficient when the wrapped value is large + /// (might result in big stack memory usage due to moves), hence its usage + /// is not recommended when the wrapped value is large. + pub fn into_option(self) -> Option { + if self.some { + Some(unsafe { self.value.assume_init() }) + } else { + None + } + } + + /// Return whether the `Maybe` value is empty. + pub fn is_none(&self) -> bool { + !self.some + } + + /// Return whether the `Maybe` value is not empty. + pub fn is_some(&self) -> bool { + self.some + } +} + +impl Default for Maybe { + fn default() -> Self { + Self::none() + } +} + +impl From> for Maybe { + fn from(value: Option) -> Self { + Self::new(value) + } +} + +impl From> for Option { + fn from(value: Maybe) -> Self { + value.into_option() + } +} + +impl Clone for Maybe +where + T: Clone, +{ + fn clone(&self) -> Self { + Maybe::<_, G>::new(self.as_ref().cloned()) + } +} + +impl Copy for Maybe where T: Copy {} + +impl PartialEq for Maybe +where + T: PartialEq, +{ + fn eq(&self, other: &Self) -> bool { + self.as_ref() == other.as_ref() + } +} + +impl Eq for Maybe where T: Eq {} + +impl Hash for Maybe +where + T: Hash, +{ + fn hash(&self, state: &mut H) { + self.as_ref().hash(state) + } +} diff --git a/rs-matter/src/utils/mod.rs b/rs-matter/src/utils/mod.rs index b7c0136c..da766956 100644 --- a/rs-matter/src/utils/mod.rs +++ b/rs-matter/src/utils/mod.rs @@ -15,14 +15,11 @@ * limitations under the License. */ -pub mod buf; +pub mod cell; pub mod epoch; -pub mod ifmutex; -pub mod notification; -pub mod parsebuf; +pub mod init; +pub mod maybe; pub mod rand; -pub mod ringbuf; pub mod select; -pub mod signal; -pub mod std_mutex; -pub mod writebuf; +pub mod storage; +pub mod sync; diff --git a/rs-matter/src/utils/std_mutex.rs b/rs-matter/src/utils/storage.rs similarity index 50% rename from rs-matter/src/utils/std_mutex.rs rename to rs-matter/src/utils/storage.rs index 868e2339..a8c24904 100644 --- a/rs-matter/src/utils/std_mutex.rs +++ b/rs-matter/src/utils/storage.rs @@ -15,28 +15,14 @@ * limitations under the License. */ -#![cfg(feature = "std")] +pub use parsebuf::*; +pub use ringbuf::*; +pub use vec::*; +pub use writebuf::*; -use embassy_sync::blocking_mutex::raw::RawMutex; +pub mod pooled; -/// An `embassy-sync` `RawMutex` implementation using `std::sync::Mutex`. -/// TODO: Upstream into `embassy-sync` itself. -#[derive(Default)] -pub struct StdRawMutex(std::sync::Mutex<()>); - -impl StdRawMutex { - pub const fn new() -> Self { - Self(std::sync::Mutex::new(())) - } -} - -unsafe impl RawMutex for StdRawMutex { - #[allow(clippy::declare_interior_mutable_const)] - const INIT: Self = StdRawMutex(std::sync::Mutex::new(())); - - fn lock(&self, f: impl FnOnce() -> R) -> R { - let _guard = self.0.lock().unwrap(); - - f() - } -} +mod parsebuf; +mod ringbuf; +mod vec; +mod writebuf; diff --git a/rs-matter/src/utils/parsebuf.rs b/rs-matter/src/utils/storage/parsebuf.rs similarity index 99% rename from rs-matter/src/utils/parsebuf.rs rename to rs-matter/src/utils/storage/parsebuf.rs index d96c416f..3bc26685 100644 --- a/rs-matter/src/utils/parsebuf.rs +++ b/rs-matter/src/utils/storage/parsebuf.rs @@ -120,7 +120,7 @@ impl<'a> ParseBuf<'a> { #[cfg(test)] mod tests { - use crate::utils::parsebuf::*; + use crate::utils::storage::ParseBuf; #[test] fn test_parse_with_success() { diff --git a/rs-matter/src/utils/buf.rs b/rs-matter/src/utils/storage/pooled.rs similarity index 85% rename from rs-matter/src/utils/buf.rs rename to rs-matter/src/utils/storage/pooled.rs index 898f3bd5..a980902d 100644 --- a/rs-matter/src/utils/buf.rs +++ b/rs-matter/src/utils/storage/pooled.rs @@ -23,7 +23,8 @@ use embassy_futures::select::{select, Either}; use embassy_sync::blocking_mutex::raw::RawMutex; use embassy_time::{Duration, Timer}; -use super::signal::Signal; +use crate::utils::init::{init, Init, UnsafeCellInit}; +use crate::utils::sync::Signal; /// A trait for getting access to a `&mut T` buffer, potentially awaiting until a buffer becomes available. pub trait BufferAccess @@ -59,7 +60,7 @@ where /// Accessing a buffer would fail when all buffers are still used elsewhere after a wait timeout expires. pub struct PooledBuffers { available: Signal, - pool: UnsafeCell>, + pool: UnsafeCell>, wait_timeout_ms: u32, } @@ -67,14 +68,30 @@ impl PooledBuffers where M: RawMutex, { + /// Create a new instance of `PooledBuffers`. + /// + /// `wait_timneout_ms` is the maximum time to wait for a buffer to become available + /// before returning `None`. #[inline(always)] pub const fn new(wait_timeout_ms: u32) -> Self { Self { available: Signal::new([true; N]), - pool: UnsafeCell::new(heapless::Vec::new()), + pool: UnsafeCell::new(crate::utils::storage::Vec::new()), wait_timeout_ms, } } + + /// Create an in-place initializer for `PooledBuffers`. + /// + /// `wait_timneout_ms` is the maximum time to wait for a buffer to become available + /// before returning `None`. + pub fn init(wait_timeout_ms: u32) -> impl Init { + init!(Self { + available: Signal::new([true; N]), + pool <- UnsafeCell::init(crate::utils::storage::Vec::init()), + wait_timeout_ms, + }) + } } impl BufferAccess for PooledBuffers diff --git a/rs-matter/src/utils/ringbuf.rs b/rs-matter/src/utils/storage/ringbuf.rs similarity index 87% rename from rs-matter/src/utils/ringbuf.rs rename to rs-matter/src/utils/storage/ringbuf.rs index a02f8264..2cefeb81 100644 --- a/rs-matter/src/utils/ringbuf.rs +++ b/rs-matter/src/utils/storage/ringbuf.rs @@ -17,13 +17,15 @@ use core::cmp::min; +use crate::utils::init::{init, Init}; + /// A ring buffer of a fixed capacity `N` using owned storage. #[derive(Debug)] pub struct RingBuf { - buf: heapless::Vec, + buf: crate::utils::storage::Vec, start: usize, end: usize, - empty: bool, + non_empty: bool, } impl Default for RingBuf { @@ -37,13 +39,23 @@ impl RingBuf { #[inline(always)] pub const fn new() -> Self { Self { - buf: heapless::Vec::new(), + buf: crate::utils::storage::Vec::new(), start: 0, end: 0, - empty: true, + non_empty: false, } } + /// Create an in-place initializer for the ring buffer. + pub fn init() -> impl Init { + init!(Self { + buf <- crate::utils::storage::Vec::init(), + start: 0, + end: 0, + non_empty: false, + }) + } + /// Push new data to the end of the buffer. /// If the data does not fit in the buffer, the oldest data is dropped to make room for the new one. /// @@ -62,7 +74,7 @@ impl RingBuf { offset += len; - if !self.empty && self.start >= self.end && self.start < self.end + len { + if self.non_empty && self.start >= self.end && self.start < self.end + len { // Dropping oldest data self.start = self.end + len; } @@ -71,7 +83,7 @@ impl RingBuf { self.wrap(); - self.empty = false; + self.non_empty = true; } self.len() @@ -88,7 +100,7 @@ impl RingBuf { self.buf[self.end] = data; - if !self.empty && self.start == self.end { + if self.non_empty && self.start == self.end { // Dropping oldest data self.start = self.end + 1; } @@ -97,7 +109,7 @@ impl RingBuf { self.wrap(); - self.empty = false; + self.non_empty = true; self.len() } @@ -121,7 +133,7 @@ impl RingBuf { pub fn pop(&mut self, out_buf: &mut [u8]) -> usize { let mut offset = 0; - while offset < out_buf.len() && !self.empty { + while offset < out_buf.len() && self.non_empty { let len = min( if self.start < self.end { self.end @@ -138,7 +150,7 @@ impl RingBuf { self.wrap(); if self.start == self.end { - self.empty = true + self.non_empty = false } offset += len; @@ -150,20 +162,20 @@ impl RingBuf { /// Return `true` when the buffer is full. #[inline(always)] pub fn is_full(&self) -> bool { - self.start == self.end && !self.empty + self.start == self.end && self.non_empty } /// Return `true` when the buffer is empty. #[inline(always)] pub fn is_empty(&self) -> bool { - self.empty + !self.non_empty } /// Return the current size of the data in the buffer. #[inline(always)] #[allow(unused)] pub fn len(&self) -> usize { - if self.empty { + if !self.non_empty { 0 } else if self.start < self.end { self.end - self.start @@ -184,7 +196,7 @@ impl RingBuf { pub fn clear(&mut self) { self.start = 0; self.end = 0; - self.empty = true; + self.non_empty = false; } #[inline(always)] diff --git a/rs-matter/src/utils/storage/vec.rs b/rs-matter/src/utils/storage/vec.rs new file mode 100644 index 00000000..49e85d80 --- /dev/null +++ b/rs-matter/src/utils/storage/vec.rs @@ -0,0 +1,1691 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! A modification of `heapless::Vec` that provides the following extra features: +//! - In-place initialization of the vec itself with `Vec::init() -> impl Init` +//! - In-place initialization of the vec members with `Vec::push_init(init: I) -> Result<(), ()>` + +#![allow(clippy::unnecessary_cast)] +#![allow(clippy::redundant_slicing)] +#![allow(clippy::result_unit_err)] +#![allow(clippy::should_implement_trait)] + +use core::{ + cmp::Ordering, + fmt, hash, + iter::FromIterator, + mem::MaybeUninit, + ops, + ptr::{self, addr_of_mut}, + slice, +}; + +use crate::utils::init::{init_from_closure, Init}; + +/// A fixed capacity [`Vec`](https://doc.rust-lang.org/std/vec/struct.Vec.html) +/// +/// # Examples +/// +/// ``` +/// use heapless::Vec; +/// +/// +/// // A vector with a fixed capacity of 8 elements allocated on the stack +/// let mut vec = Vec::<_, 8>::new(); +/// vec.push(1); +/// vec.push(2); +/// +/// assert_eq!(vec.len(), 2); +/// assert_eq!(vec[0], 1); +/// +/// assert_eq!(vec.pop(), Some(2)); +/// assert_eq!(vec.len(), 1); +/// +/// vec[0] = 7; +/// assert_eq!(vec[0], 7); +/// +/// vec.extend([1, 2, 3].iter().cloned()); +/// +/// for x in &vec { +/// println!("{}", x); +/// } +/// assert_eq!(*vec, [7, 1, 2, 3]); +/// ``` +pub struct Vec { + // NOTE order is important for optimizations. the `len` first layout lets the compiler optimize + // `new` to: reserve stack space and zero the first word. With the fields in the reverse order + // the compiler optimizes `new` to `memclr`-ing the *entire* stack space, including the `buffer` + // field which should be left uninitialized. Optimizations were last checked with Rust 1.60 + len: usize, + + buffer: [MaybeUninit; N], +} + +impl Vec { + const ELEM: MaybeUninit = MaybeUninit::uninit(); + const INIT: [MaybeUninit; N] = [Self::ELEM; N]; // important for optimization of `new` + + /// Constructs a new, empty vector with a fixed capacity of `N` + /// + /// # Examples + /// + /// ``` + /// use heapless::Vec; + /// + /// // allocate the vector on the stack + /// let mut x: Vec = Vec::new(); + /// + /// // allocate the vector in a static variable + /// static mut X: Vec = Vec::new(); + /// ``` + /// `Vec` `const` constructor; wrap the returned value in [`Vec`]. + pub const fn new() -> Self { + Self { + len: 0, + buffer: Self::INIT, + } + } + + /// Returns an in-place initializer for a new, empty vector. + pub fn init() -> impl Init { + unsafe { + init_from_closure(move |slot: *mut Self| { + addr_of_mut!((*slot).len).write(0); + + Ok(()) + }) + } + } + + /// Constructs a new vector with a fixed capacity of `N` and fills it + /// with the provided slice. + /// + /// This is equivalent to the following code: + /// + /// ``` + /// use heapless::Vec; + /// + /// let mut v: Vec = Vec::new(); + /// v.extend_from_slice(&[1, 2, 3]).unwrap(); + /// ``` + #[inline] + pub fn from_slice(other: &[T]) -> Result + where + T: Clone, + { + let mut v = Vec::new(); + v.extend_from_slice(other)?; + Ok(v) + } + + /// Clones a vec into a new vec + pub(crate) fn clone(&self) -> Self + where + T: Clone, + { + let mut new = Self::new(); + // avoid `extend_from_slice` as that introduces a runtime check / panicking branch + for elem in self { + unsafe { + new.push_unchecked(elem.clone()); + } + } + new + } + + /// Returns a raw pointer to the vector’s buffer. + pub fn as_ptr(&self) -> *const T { + self.buffer.as_ptr() as *const T + } + + /// Returns a raw pointer to the vector’s buffer, which may be mutated through. + pub fn as_mut_ptr(&mut self) -> *mut T { + self.buffer.as_mut_ptr() as *mut T + } + + /// Extracts a slice containing the entire vector. + /// + /// Equivalent to `&s[..]`. + /// + /// # Examples + /// + /// ``` + /// use heapless::Vec; + /// let buffer: Vec = Vec::from_slice(&[1, 2, 3, 5, 8]).unwrap(); + /// assert_eq!(buffer.as_slice(), &[1, 2, 3, 5, 8]); + /// ``` + pub fn as_slice(&self) -> &[T] { + // NOTE(unsafe) avoid bound checks in the slicing operation + // &buffer[..self.len] + unsafe { slice::from_raw_parts(self.buffer.as_ptr() as *const T, self.len) } + } + + /// Returns the contents of the vector as an array of length `M` if the length + /// of the vector is exactly `M`, otherwise returns `Err(self)`. + /// + /// # Examples + /// + /// ``` + /// use heapless::Vec; + /// let buffer: Vec = Vec::from_slice(&[1, 2, 3, 5, 8]).unwrap(); + /// let array: [u8; 5] = buffer.into_array().unwrap(); + /// assert_eq!(array, [1, 2, 3, 5, 8]); + /// ``` + pub fn into_array(self) -> Result<[T; M], Self> { + if self.len() == M { + // This is how the unstable `MaybeUninit::array_assume_init` method does it + let array = unsafe { (&self.buffer as *const _ as *const [T; M]).read() }; + + // We don't want `self`'s destructor to be called because that would drop all the + // items in the array + core::mem::forget(self); + + Ok(array) + } else { + Err(self) + } + } + + /// Extracts a mutable slice containing the entire vector. + /// + /// Equivalent to `&mut s[..]`. + /// + /// # Examples + /// + /// ``` + /// use heapless::Vec; + /// let mut buffer: Vec = Vec::from_slice(&[1, 2, 3, 5, 8]).unwrap(); + /// buffer[0] = 9; + /// assert_eq!(buffer.as_slice(), &[9, 2, 3, 5, 8]); + /// ``` + pub fn as_mut_slice(&mut self) -> &mut [T] { + // NOTE(unsafe) avoid bound checks in the slicing operation + // &mut buffer[..self.len] + unsafe { slice::from_raw_parts_mut(self.buffer.as_mut_ptr() as *mut T, self.len) } + } + + /// Returns the maximum number of elements the vector can hold. + pub const fn capacity(&self) -> usize { + N + } + + /// Clears the vector, removing all values. + pub fn clear(&mut self) { + self.truncate(0); + } + + /// Extends the vec from an iterator. + /// + /// # Panic + /// + /// Panics if the vec cannot hold all elements of the iterator. + pub fn extend(&mut self, iter: I) + where + I: IntoIterator, + { + for elem in iter { + self.push(elem).ok().unwrap() + } + } + + /// Clones and appends all elements in a slice to the `Vec`. + /// + /// Iterates over the slice `other`, clones each element, and then appends + /// it to this `Vec`. The `other` vector is traversed in-order. + /// + /// # Examples + /// + /// ``` + /// use heapless::Vec; + /// + /// let mut vec = Vec::::new(); + /// vec.push(1).unwrap(); + /// vec.extend_from_slice(&[2, 3, 4]).unwrap(); + /// assert_eq!(*vec, [1, 2, 3, 4]); + /// ``` + pub fn extend_from_slice(&mut self, other: &[T]) -> Result<(), ()> + where + T: Clone, + { + if self.len + other.len() > self.capacity() { + // won't fit in the `Vec`; don't modify anything and return an error + Err(()) + } else { + for elem in other { + unsafe { + self.push_unchecked(elem.clone()); + } + } + Ok(()) + } + } + + /// Removes the last element from a vector and returns it, or `None` if it's empty + pub fn pop(&mut self) -> Option { + if self.len != 0 { + Some(unsafe { self.pop_unchecked() }) + } else { + None + } + } + + /// Appends an `item` to the back of the collection + /// + /// Returns back the `item` if the vector is full + pub fn push(&mut self, item: T) -> Result<(), T> { + if self.len < self.capacity() { + unsafe { self.push_unchecked(item) } + Ok(()) + } else { + Err(item) + } + } + + /// Appends an item with the provided item initializer - `init` + /// to the back of the collection + /// + /// Returns an error generated by `f` if the vector is full + pub fn push_init, E, F: FnOnce() -> E>( + &mut self, + init: I, + f: F, + ) -> Result<(), E> { + if self.len < self.capacity() { + self.push_init_unchecked(init) + } else { + Err(f()) + } + } + + /// Removes the last element from a vector and returns it + /// + /// # Safety + /// + /// This assumes the vec to have at least one element. + pub unsafe fn pop_unchecked(&mut self) -> T { + debug_assert!(!self.is_empty()); + + self.len -= 1; + (self.buffer.get_unchecked_mut(self.len).as_ptr() as *const T).read() + } + + /// Appends an `item` to the back of the collection + /// + /// # Safety + /// + /// This assumes the vec is not full. + pub unsafe fn push_unchecked(&mut self, item: T) { + // NOTE(ptr::write) the memory slot that we are about to write to is uninitialized. We + // use `ptr::write` to avoid running `T`'s destructor on the uninitialized memory + debug_assert!(!self.is_full()); + + *self.buffer.get_unchecked_mut(self.len) = MaybeUninit::new(item); + + self.len += 1; + } + + /// Appends an item with the provided item initializer - `init` + /// to the back of the collection + /// + /// Panics if the vec is full. + pub fn push_init_unchecked, E>(&mut self, init: I) -> Result<(), E> { + if self.is_full() { + panic!("Vec::push_init_unchecked: vec is full"); + } + + unsafe { + // NOTE(ptr::write) the memory slot that we are about to write to is uninitialized. We + // use `ptr::write` to avoid running `T`'s destructor on the uninitialized memory + let buffer: *mut T = self.buffer.as_mut_ptr().add(self.len) as _; + + init.__init(buffer)?; + } + + self.len += 1; + + Ok(()) + } + + /// Shortens the vector, keeping the first `len` elements and dropping the rest. + pub fn truncate(&mut self, len: usize) { + // This is safe because: + // + // * the slice passed to `drop_in_place` is valid; the `len > self.len` + // case avoids creating an invalid slice, and + // * the `len` of the vector is shrunk before calling `drop_in_place`, + // such that no value will be dropped twice in case `drop_in_place` + // were to panic once (if it panics twice, the program aborts). + unsafe { + // Note: It's intentional that this is `>` and not `>=`. + // Changing it to `>=` has negative performance + // implications in some cases. See rust-lang/rust#78884 for more. + if len > self.len { + return; + } + let remaining_len = self.len - len; + let s = ptr::slice_from_raw_parts_mut(self.as_mut_ptr().add(len), remaining_len); + self.len = len; + ptr::drop_in_place(s); + } + } + + /// Resizes the Vec in-place so that len is equal to new_len. + /// + /// If new_len is greater than len, the Vec is extended by the + /// difference, with each additional slot filled with value. If + /// new_len is less than len, the Vec is simply truncated. + /// + /// See also [`resize_default`](Self::resize_default). + pub fn resize(&mut self, new_len: usize, value: T) -> Result<(), ()> + where + T: Clone, + { + if new_len > self.capacity() { + return Err(()); + } + + if new_len > self.len { + while self.len < new_len { + self.push(value.clone()).ok(); + } + } else { + self.truncate(new_len); + } + + Ok(()) + } + + /// Resizes the `Vec` in-place so that `len` is equal to `new_len`. + /// + /// If `new_len` is greater than `len`, the `Vec` is extended by the + /// difference, with each additional slot filled with `Default::default()`. + /// If `new_len` is less than `len`, the `Vec` is simply truncated. + /// + /// See also [`resize`](Self::resize). + pub fn resize_default(&mut self, new_len: usize) -> Result<(), ()> + where + T: Clone + Default, + { + self.resize(new_len, T::default()) + } + + /// Forces the length of the vector to `new_len`. + /// + /// This is a low-level operation that maintains none of the normal + /// invariants of the type. Normally changing the length of a vector + /// is done using one of the safe operations instead, such as + /// [`truncate`], [`resize`], [`extend`], or [`clear`]. + /// + /// [`truncate`]: Self::truncate + /// [`resize`]: Self::resize + /// [`extend`]: core::iter::Extend + /// [`clear`]: Self::clear + /// + /// # Safety + /// + /// - `new_len` must be less than or equal to [`capacity()`]. + /// - The elements at `old_len..new_len` must be initialized. + /// + /// [`capacity()`]: Self::capacity + /// + /// # Examples + /// + /// This method can be useful for situations in which the vector + /// is serving as a buffer for other code, particularly over FFI: + /// + /// ```no_run + /// # #![allow(dead_code)] + /// use heapless::Vec; + /// + /// # // This is just a minimal skeleton for the doc example; + /// # // don't use this as a starting point for a real library. + /// # pub struct StreamWrapper { strm: *mut core::ffi::c_void } + /// # const Z_OK: i32 = 0; + /// # extern "C" { + /// # fn deflateGetDictionary( + /// # strm: *mut core::ffi::c_void, + /// # dictionary: *mut u8, + /// # dictLength: *mut usize, + /// # ) -> i32; + /// # } + /// # impl StreamWrapper { + /// pub fn get_dictionary(&self) -> Option> { + /// // Per the FFI method's docs, "32768 bytes is always enough". + /// let mut dict = Vec::new(); + /// let mut dict_length = 0; + /// // SAFETY: When `deflateGetDictionary` returns `Z_OK`, it holds that: + /// // 1. `dict_length` elements were initialized. + /// // 2. `dict_length` <= the capacity (32_768) + /// // which makes `set_len` safe to call. + /// unsafe { + /// // Make the FFI call... + /// let r = deflateGetDictionary(self.strm, dict.as_mut_ptr(), &mut dict_length); + /// if r == Z_OK { + /// // ...and update the length to what was initialized. + /// dict.set_len(dict_length); + /// Some(dict) + /// } else { + /// None + /// } + /// } + /// } + /// # } + /// ``` + /// + /// While the following example is sound, there is a memory leak since + /// the inner vectors were not freed prior to the `set_len` call: + /// + /// ``` + /// use core::iter::FromIterator; + /// use heapless::Vec; + /// + /// let mut vec = Vec::, 3>::from_iter( + /// [ + /// Vec::from_iter([1, 0, 0].iter().cloned()), + /// Vec::from_iter([0, 1, 0].iter().cloned()), + /// Vec::from_iter([0, 0, 1].iter().cloned()), + /// ] + /// .iter() + /// .cloned() + /// ); + /// // SAFETY: + /// // 1. `old_len..0` is empty so no elements need to be initialized. + /// // 2. `0 <= capacity` always holds whatever `capacity` is. + /// unsafe { + /// vec.set_len(0); + /// } + /// ``` + /// + /// Normally, here, one would use [`clear`] instead to correctly drop + /// the contents and thus not leak memory. + pub unsafe fn set_len(&mut self, new_len: usize) { + debug_assert!(new_len <= self.capacity()); + + self.len = new_len + } + + /// Removes an element from the vector and returns it. + /// + /// The removed element is replaced by the last element of the vector. + /// + /// This does not preserve ordering, but is O(1). + /// + /// # Panics + /// + /// Panics if `index` is out of bounds. + /// + /// # Examples + /// + /// ``` + /// use heapless::Vec; + ///// use heapless::consts::*; + /// + /// let mut v: Vec<_, 8> = Vec::new(); + /// v.push("foo").unwrap(); + /// v.push("bar").unwrap(); + /// v.push("baz").unwrap(); + /// v.push("qux").unwrap(); + /// + /// assert_eq!(v.swap_remove(1), "bar"); + /// assert_eq!(&*v, ["foo", "qux", "baz"]); + /// + /// assert_eq!(v.swap_remove(0), "foo"); + /// assert_eq!(&*v, ["baz", "qux"]); + /// ``` + pub fn swap_remove(&mut self, index: usize) -> T { + assert!(index < self.len); + unsafe { self.swap_remove_unchecked(index) } + } + + /// Removes an element from the vector and returns it. + /// + /// The removed element is replaced by the last element of the vector. + /// + /// This does not preserve ordering, but is O(1). + /// + /// # Safety + /// + /// Assumes `index` within bounds. + /// + /// # Examples + /// + /// ``` + /// use heapless::Vec; + /// + /// let mut v: Vec<_, 8> = Vec::new(); + /// v.push("foo").unwrap(); + /// v.push("bar").unwrap(); + /// v.push("baz").unwrap(); + /// v.push("qux").unwrap(); + /// + /// assert_eq!(unsafe { v.swap_remove_unchecked(1) }, "bar"); + /// assert_eq!(&*v, ["foo", "qux", "baz"]); + /// + /// assert_eq!(unsafe { v.swap_remove_unchecked(0) }, "foo"); + /// assert_eq!(&*v, ["baz", "qux"]); + /// ``` + pub unsafe fn swap_remove_unchecked(&mut self, index: usize) -> T { + let length = self.len(); + debug_assert!(index < length); + let value = ptr::read(self.as_ptr().add(index)); + let base_ptr = self.as_mut_ptr(); + ptr::copy(base_ptr.add(length - 1), base_ptr.add(index), 1); + self.len -= 1; + value + } + + /// Returns true if the vec is full + #[inline] + pub fn is_full(&self) -> bool { + self.len == self.capacity() + } + + /// Returns true if the vec is empty + #[inline] + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + /// Returns `true` if `needle` is a prefix of the Vec. + /// + /// Always returns `true` if `needle` is an empty slice. + /// + /// # Examples + /// + /// ``` + /// use heapless::Vec; + /// + /// let v: Vec<_, 8> = Vec::from_slice(b"abc").unwrap(); + /// assert_eq!(v.starts_with(b""), true); + /// assert_eq!(v.starts_with(b"ab"), true); + /// assert_eq!(v.starts_with(b"bc"), false); + /// ``` + #[inline] + pub fn starts_with(&self, needle: &[T]) -> bool + where + T: PartialEq, + { + let n = needle.len(); + self.len >= n && needle == &self[..n] + } + + /// Returns `true` if `needle` is a suffix of the Vec. + /// + /// Always returns `true` if `needle` is an empty slice. + /// + /// # Examples + /// + /// ``` + /// use heapless::Vec; + /// + /// let v: Vec<_, 8> = Vec::from_slice(b"abc").unwrap(); + /// assert_eq!(v.ends_with(b""), true); + /// assert_eq!(v.ends_with(b"ab"), false); + /// assert_eq!(v.ends_with(b"bc"), true); + /// ``` + #[inline] + pub fn ends_with(&self, needle: &[T]) -> bool + where + T: PartialEq, + { + let (v, n) = (self.len(), needle.len()); + v >= n && needle == &self[v - n..] + } + + /// Inserts an element at position `index` within the vector, shifting all + /// elements after it to the right. + /// + /// Returns back the `element` if the vector is full. + /// + /// # Panics + /// + /// Panics if `index > len`. + /// + /// # Examples + /// + /// ``` + /// use heapless::Vec; + /// + /// let mut vec: Vec<_, 8> = Vec::from_slice(&[1, 2, 3]).unwrap(); + /// vec.insert(1, 4); + /// assert_eq!(vec, [1, 4, 2, 3]); + /// vec.insert(4, 5); + /// assert_eq!(vec, [1, 4, 2, 3, 5]); + /// ``` + pub fn insert(&mut self, index: usize, element: T) -> Result<(), T> { + let len = self.len(); + if index > len { + panic!( + "insertion index (is {}) should be <= len (is {})", + index, len + ); + } + + // check there's space for the new element + if self.is_full() { + return Err(element); + } + + unsafe { + // infallible + // The spot to put the new value + { + let p = self.as_mut_ptr().add(index); + // Shift everything over to make space. (Duplicating the + // `index`th element into two consecutive places.) + ptr::copy(p, p.offset(1), len - index); + // Write it in, overwriting the first copy of the `index`th + // element. + ptr::write(p, element); + } + self.set_len(len + 1); + } + + Ok(()) + } + + /// Removes and returns the element at position `index` within the vector, + /// shifting all elements after it to the left. + /// + /// Note: Because this shifts over the remaining elements, it has a + /// worst-case performance of *O*(*n*). If you don't need the order of + /// elements to be preserved, use [`swap_remove`] instead. If you'd like to + /// remove elements from the beginning of the `Vec`, consider using + /// [`Deque::pop_front`] instead. + /// + /// [`swap_remove`]: Vec::swap_remove + /// [`Deque::pop_front`]: crate::Deque::pop_front + /// + /// # Panics + /// + /// Panics if `index` is out of bounds. + /// + /// # Examples + /// + /// ``` + /// use heapless::Vec; + /// + /// let mut v: Vec<_, 8> = Vec::from_slice(&[1, 2, 3]).unwrap(); + /// assert_eq!(v.remove(1), 2); + /// assert_eq!(v, [1, 3]); + /// ``` + pub fn remove(&mut self, index: usize) -> T { + let len = self.len(); + if index >= len { + panic!("removal index (is {}) should be < len (is {})", index, len); + } + unsafe { + // infallible + let ret; + { + // the place we are taking from. + let ptr = self.as_mut_ptr().add(index); + // copy it out, unsafely having a copy of the value on + // the stack and in the vector at the same time. + ret = ptr::read(ptr); + + // Shift everything down to fill in that spot. + ptr::copy(ptr.offset(1), ptr, len - index - 1); + } + self.set_len(len - 1); + ret + } + } + + /// Retains only the elements specified by the predicate. + /// + /// In other words, remove all elements `e` for which `f(&e)` returns `false`. + /// This method operates in place, visiting each element exactly once in the + /// original order, and preserves the order of the retained elements. + /// + /// # Examples + /// + /// ``` + /// use heapless::Vec; + /// + /// let mut vec: Vec<_, 8> = Vec::from_slice(&[1, 2, 3, 4]).unwrap(); + /// vec.retain(|&x| x % 2 == 0); + /// assert_eq!(vec, [2, 4]); + /// ``` + /// + /// Because the elements are visited exactly once in the original order, + /// external state may be used to decide which elements to keep. + /// + /// ``` + /// use heapless::Vec; + /// + /// let mut vec: Vec<_, 8> = Vec::from_slice(&[1, 2, 3, 4, 5]).unwrap(); + /// let keep = [false, true, true, false, true]; + /// let mut iter = keep.iter(); + /// vec.retain(|_| *iter.next().unwrap()); + /// assert_eq!(vec, [2, 3, 5]); + /// ``` + pub fn retain(&mut self, mut f: F) + where + F: FnMut(&T) -> bool, + { + self.retain_mut(|elem| f(elem)); + } + + /// Retains only the elements specified by the predicate, passing a mutable reference to it. + /// + /// In other words, remove all elements `e` such that `f(&mut e)` returns `false`. + /// This method operates in place, visiting each element exactly once in the + /// original order, and preserves the order of the retained elements. + /// + /// # Examples + /// + /// ``` + /// use heapless::Vec; + /// + /// let mut vec: Vec<_, 8> = Vec::from_slice(&[1, 2, 3, 4]).unwrap(); + /// vec.retain_mut(|x| if *x <= 3 { + /// *x += 1; + /// true + /// } else { + /// false + /// }); + /// assert_eq!(vec, [2, 3, 4]); + /// ``` + pub fn retain_mut(&mut self, mut f: F) + where + F: FnMut(&mut T) -> bool, + { + let original_len = self.len(); + // Avoid double drop if the drop guard is not executed, + // since we may make some holes during the process. + unsafe { self.set_len(0) }; + + // Vec: [Kept, Kept, Hole, Hole, Hole, Hole, Unchecked, Unchecked] + // |<- processed len ->| ^- next to check + // |<- deleted cnt ->| + // |<- original_len ->| + // Kept: Elements which predicate returns true on. + // Hole: Moved or dropped element slot. + // Unchecked: Unchecked valid elements. + // + // This drop guard will be invoked when predicate or `drop` of element panicked. + // It shifts unchecked elements to cover holes and `set_len` to the correct length. + // In cases when predicate and `drop` never panick, it will be optimized out. + struct BackshiftOnDrop<'a, T, const N: usize> { + v: &'a mut Vec, + processed_len: usize, + deleted_cnt: usize, + original_len: usize, + } + + impl Drop for BackshiftOnDrop<'_, T, N> { + fn drop(&mut self) { + if self.deleted_cnt > 0 { + // SAFETY: Trailing unchecked items must be valid since we never touch them. + unsafe { + ptr::copy( + self.v.as_ptr().add(self.processed_len), + self.v + .as_mut_ptr() + .add(self.processed_len - self.deleted_cnt), + self.original_len - self.processed_len, + ); + } + } + // SAFETY: After filling holes, all items are in contiguous memory. + unsafe { + self.v.set_len(self.original_len - self.deleted_cnt); + } + } + } + + let mut g = BackshiftOnDrop { + v: self, + processed_len: 0, + deleted_cnt: 0, + original_len, + }; + + fn process_loop( + original_len: usize, + f: &mut F, + g: &mut BackshiftOnDrop<'_, T, N>, + ) where + F: FnMut(&mut T) -> bool, + { + while g.processed_len != original_len { + let p = g.v.as_mut_ptr(); + // SAFETY: Unchecked element must be valid. + let cur = unsafe { &mut *p.add(g.processed_len) }; + if !f(cur) { + // Advance early to avoid double drop if `drop_in_place` panicked. + g.processed_len += 1; + g.deleted_cnt += 1; + // SAFETY: We never touch this element again after dropped. + unsafe { ptr::drop_in_place(cur) }; + // We already advanced the counter. + if DELETED { + continue; + } else { + break; + } + } + if DELETED { + // SAFETY: `deleted_cnt` > 0, so the hole slot must not overlap with current element. + // We use copy for move, and never touch this element again. + unsafe { + let hole_slot = p.add(g.processed_len - g.deleted_cnt); + ptr::copy_nonoverlapping(cur, hole_slot, 1); + } + } + g.processed_len += 1; + } + } + + // Stage 1: Nothing was deleted. + process_loop::(original_len, &mut f, &mut g); + + // Stage 2: Some elements were deleted. + process_loop::(original_len, &mut f, &mut g); + + // All item are processed. This can be optimized to `set_len` by LLVM. + drop(g); + } +} + +// Trait implementations + +impl Default for Vec { + fn default() -> Self { + Self::new() + } +} + +impl fmt::Debug for Vec +where + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + <[T] as fmt::Debug>::fmt(self, f) + } +} + +impl fmt::Write for Vec { + fn write_str(&mut self, s: &str) -> fmt::Result { + match self.extend_from_slice(s.as_bytes()) { + Ok(()) => Ok(()), + Err(_) => Err(fmt::Error), + } + } +} + +impl Drop for Vec { + fn drop(&mut self) { + // We drop each element used in the vector by turning into a &mut[T] + unsafe { + ptr::drop_in_place(self.as_mut_slice()); + } + } +} + +impl<'a, T: Clone, const N: usize> TryFrom<&'a [T]> for Vec { + type Error = (); + + fn try_from(slice: &'a [T]) -> Result { + Vec::from_slice(slice) + } +} + +impl Extend for Vec { + fn extend(&mut self, iter: I) + where + I: IntoIterator, + { + self.extend(iter) + } +} + +impl<'a, T, const N: usize> Extend<&'a T> for Vec +where + T: 'a + Copy, +{ + fn extend(&mut self, iter: I) + where + I: IntoIterator, + { + self.extend(iter.into_iter().cloned()) + } +} + +impl hash::Hash for Vec +where + T: core::hash::Hash, +{ + fn hash(&self, state: &mut H) { + <[T] as hash::Hash>::hash(self, state) + } +} + +impl<'a, T, const N: usize> IntoIterator for &'a Vec { + type Item = &'a T; + type IntoIter = slice::Iter<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a, T, const N: usize> IntoIterator for &'a mut Vec { + type Item = &'a mut T; + type IntoIter = slice::IterMut<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + self.iter_mut() + } +} + +impl FromIterator for Vec { + fn from_iter(iter: I) -> Self + where + I: IntoIterator, + { + let mut vec = Vec::new(); + for i in iter { + vec.push(i).ok().expect("Vec::from_iter overflow"); + } + vec + } +} + +/// An iterator that moves out of an [`Vec`][`Vec`]. +/// +/// This struct is created by calling the `into_iter` method on [`Vec`][`Vec`]. +pub struct IntoIter { + vec: Vec, + next: usize, +} + +impl Iterator for IntoIter { + type Item = T; + fn next(&mut self) -> Option { + if self.next < self.vec.len() { + let item = unsafe { + (self.vec.buffer.get_unchecked_mut(self.next).as_ptr() as *const T).read() + }; + self.next += 1; + Some(item) + } else { + None + } + } +} + +impl Clone for IntoIter +where + T: Clone, +{ + fn clone(&self) -> Self { + let mut vec = Vec::new(); + + if self.next < self.vec.len() { + let s = unsafe { + slice::from_raw_parts( + (self.vec.buffer.as_ptr() as *const T).add(self.next), + self.vec.len() - self.next, + ) + }; + vec.extend_from_slice(s).ok(); + } + + Self { vec, next: 0 } + } +} + +impl Drop for IntoIter { + fn drop(&mut self) { + unsafe { + // Drop all the elements that have not been moved out of vec + ptr::drop_in_place(&mut self.vec.as_mut_slice()[self.next..]); + // Prevent dropping of other elements + self.vec.len = 0; + } + } +} + +impl IntoIterator for Vec { + type Item = T; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + IntoIter { vec: self, next: 0 } + } +} + +impl PartialEq> for Vec +where + A: PartialEq, +{ + fn eq(&self, other: &Vec) -> bool { + <[A]>::eq(self, &**other) + } +} + +// Vec == [B] +impl PartialEq<[B]> for Vec +where + A: PartialEq, +{ + fn eq(&self, other: &[B]) -> bool { + <[A]>::eq(self, &other[..]) + } +} + +// [B] == Vec +impl PartialEq> for [B] +where + A: PartialEq, +{ + fn eq(&self, other: &Vec) -> bool { + <[A]>::eq(other, &self[..]) + } +} + +// Vec == &[B] +impl PartialEq<&[B]> for Vec +where + A: PartialEq, +{ + fn eq(&self, other: &&[B]) -> bool { + <[A]>::eq(self, &other[..]) + } +} + +// &[B] == Vec +impl PartialEq> for &[B] +where + A: PartialEq, +{ + fn eq(&self, other: &Vec) -> bool { + <[A]>::eq(other, &self[..]) + } +} + +// Vec == &mut [B] +impl PartialEq<&mut [B]> for Vec +where + A: PartialEq, +{ + fn eq(&self, other: &&mut [B]) -> bool { + <[A]>::eq(self, &other[..]) + } +} + +// &mut [B] == Vec +impl PartialEq> for &mut [B] +where + A: PartialEq, +{ + fn eq(&self, other: &Vec) -> bool { + <[A]>::eq(other, &self[..]) + } +} + +// Vec == [B; M] +// Equality does not require equal capacity +impl PartialEq<[B; M]> for Vec +where + A: PartialEq, +{ + fn eq(&self, other: &[B; M]) -> bool { + <[A]>::eq(self, &other[..]) + } +} + +// [B; M] == Vec +// Equality does not require equal capacity +impl PartialEq> for [B; M] +where + A: PartialEq, +{ + fn eq(&self, other: &Vec) -> bool { + <[A]>::eq(other, &self[..]) + } +} + +// Vec == &[B; M] +// Equality does not require equal capacity +impl PartialEq<&[B; M]> for Vec +where + A: PartialEq, +{ + fn eq(&self, other: &&[B; M]) -> bool { + <[A]>::eq(self, &other[..]) + } +} + +// &[B; M] == Vec +// Equality does not require equal capacity +impl PartialEq> for &[B; M] +where + A: PartialEq, +{ + fn eq(&self, other: &Vec) -> bool { + <[A]>::eq(other, &self[..]) + } +} + +// Implements Eq if underlying data is Eq +impl Eq for Vec where T: Eq {} + +impl PartialOrd> for Vec +where + T: PartialOrd, +{ + fn partial_cmp(&self, other: &Vec) -> Option { + PartialOrd::partial_cmp(&**self, &**other) + } +} + +impl Ord for Vec +where + T: Ord, +{ + #[inline] + fn cmp(&self, other: &Self) -> Ordering { + Ord::cmp(&**self, &**other) + } +} + +impl ops::Deref for Vec { + type Target = [T]; + + fn deref(&self) -> &[T] { + self.as_slice() + } +} + +impl ops::DerefMut for Vec { + fn deref_mut(&mut self) -> &mut [T] { + self.as_mut_slice() + } +} + +impl AsRef> for Vec { + #[inline] + fn as_ref(&self) -> &Self { + self + } +} + +impl AsMut> for Vec { + #[inline] + fn as_mut(&mut self) -> &mut Self { + self + } +} + +impl AsRef<[T]> for Vec { + #[inline] + fn as_ref(&self) -> &[T] { + self + } +} + +impl AsMut<[T]> for Vec { + #[inline] + fn as_mut(&mut self) -> &mut [T] { + self + } +} + +impl Clone for Vec +where + T: Clone, +{ + fn clone(&self) -> Self { + self.clone() + } +} + +#[cfg(test)] +mod tests { + use core::fmt::Write; + + use super::Vec; + + macro_rules! droppable { + () => { + static COUNT: core::sync::atomic::AtomicI32 = core::sync::atomic::AtomicI32::new(0); + + #[derive(Eq, Ord, PartialEq, PartialOrd)] + struct Droppable(i32); + impl Droppable { + fn new() -> Self { + COUNT.fetch_add(1, core::sync::atomic::Ordering::Relaxed); + Droppable(Self::count()) + } + + fn count() -> i32 { + COUNT.load(core::sync::atomic::Ordering::Relaxed) + } + } + impl Drop for Droppable { + fn drop(&mut self) { + COUNT.fetch_sub(1, core::sync::atomic::Ordering::Relaxed); + } + } + }; + } + + #[test] + fn static_new() { + static mut _V: Vec = Vec::new(); + } + + #[test] + fn stack_new() { + let mut _v: Vec = Vec::new(); + } + + #[test] + fn is_full_empty() { + let mut v: Vec = Vec::new(); + + assert!(v.is_empty()); + assert!(!v.is_full()); + + v.push(1).unwrap(); + assert!(!v.is_empty()); + assert!(!v.is_full()); + + v.push(1).unwrap(); + assert!(!v.is_empty()); + assert!(!v.is_full()); + + v.push(1).unwrap(); + assert!(!v.is_empty()); + assert!(!v.is_full()); + + v.push(1).unwrap(); + assert!(!v.is_empty()); + assert!(v.is_full()); + } + + #[test] + fn drop() { + droppable!(); + + { + let mut v: Vec = Vec::new(); + v.push(Droppable::new()).ok().unwrap(); + v.push(Droppable::new()).ok().unwrap(); + v.pop().unwrap(); + } + + assert_eq!(Droppable::count(), 0); + + { + let mut v: Vec = Vec::new(); + v.push(Droppable::new()).ok().unwrap(); + v.push(Droppable::new()).ok().unwrap(); + } + + assert_eq!(Droppable::count(), 0); + } + + #[test] + fn eq() { + let mut xs: Vec = Vec::new(); + let mut ys: Vec = Vec::new(); + + assert_eq!(xs, ys); + + xs.push(1).unwrap(); + ys.push(1).unwrap(); + + assert_eq!(xs, ys); + } + + #[test] + fn cmp() { + let mut xs: Vec = Vec::new(); + let mut ys: Vec = Vec::new(); + + assert_eq!(xs, ys); + + xs.push(1).unwrap(); + ys.push(2).unwrap(); + + assert!(xs < ys); + } + + #[test] + fn cmp_heterogenous_size() { + let mut xs: Vec = Vec::new(); + let mut ys: Vec = Vec::new(); + + assert_eq!(xs, ys); + + xs.push(1).unwrap(); + ys.push(2).unwrap(); + + assert!(xs < ys); + } + + #[test] + fn cmp_with_arrays_and_slices() { + let mut xs: Vec = Vec::new(); + xs.push(1).unwrap(); + + let array = [1]; + + assert_eq!(xs, array); + assert_eq!(array, xs); + + assert_eq!(xs, array.as_slice()); + assert_eq!(array.as_slice(), xs); + + assert_eq!(xs, &array); + assert_eq!(&array, xs); + + let longer_array = [1; 20]; + + assert_ne!(xs, longer_array); + assert_ne!(longer_array, xs); + } + + #[test] + fn full() { + let mut v: Vec = Vec::new(); + + v.push(0).unwrap(); + v.push(1).unwrap(); + v.push(2).unwrap(); + v.push(3).unwrap(); + + assert!(v.push(4).is_err()); + } + + #[test] + fn iter() { + let mut v: Vec = Vec::new(); + + v.push(0).unwrap(); + v.push(1).unwrap(); + v.push(2).unwrap(); + v.push(3).unwrap(); + + let mut items = v.iter(); + + assert_eq!(items.next(), Some(&0)); + assert_eq!(items.next(), Some(&1)); + assert_eq!(items.next(), Some(&2)); + assert_eq!(items.next(), Some(&3)); + assert_eq!(items.next(), None); + } + + #[test] + fn iter_mut() { + let mut v: Vec = Vec::new(); + + v.push(0).unwrap(); + v.push(1).unwrap(); + v.push(2).unwrap(); + v.push(3).unwrap(); + + let mut items = v.iter_mut(); + + assert_eq!(items.next(), Some(&mut 0)); + assert_eq!(items.next(), Some(&mut 1)); + assert_eq!(items.next(), Some(&mut 2)); + assert_eq!(items.next(), Some(&mut 3)); + assert_eq!(items.next(), None); + } + + #[test] + fn collect_from_iter() { + let slice = &[1, 2, 3]; + let vec: Vec = slice.iter().cloned().collect(); + assert_eq!(&vec, slice); + } + + #[test] + #[should_panic] + fn collect_from_iter_overfull() { + let slice = &[1, 2, 3]; + let _vec = slice.iter().cloned().collect::>(); + } + + #[test] + fn iter_move() { + let mut v: Vec = Vec::new(); + v.push(0).unwrap(); + v.push(1).unwrap(); + v.push(2).unwrap(); + v.push(3).unwrap(); + + let mut items = v.into_iter(); + + assert_eq!(items.next(), Some(0)); + assert_eq!(items.next(), Some(1)); + assert_eq!(items.next(), Some(2)); + assert_eq!(items.next(), Some(3)); + assert_eq!(items.next(), None); + } + + #[test] + fn iter_move_drop() { + droppable!(); + + { + let mut vec: Vec = Vec::new(); + vec.push(Droppable::new()).ok().unwrap(); + vec.push(Droppable::new()).ok().unwrap(); + let mut items = vec.into_iter(); + // Move all + let _ = items.next(); + let _ = items.next(); + } + + assert_eq!(Droppable::count(), 0); + + { + let mut vec: Vec = Vec::new(); + vec.push(Droppable::new()).ok().unwrap(); + vec.push(Droppable::new()).ok().unwrap(); + let _items = vec.into_iter(); + // Move none + } + + assert_eq!(Droppable::count(), 0); + + { + let mut vec: Vec = Vec::new(); + vec.push(Droppable::new()).ok().unwrap(); + vec.push(Droppable::new()).ok().unwrap(); + let mut items = vec.into_iter(); + let _ = items.next(); // Move partly + } + + assert_eq!(Droppable::count(), 0); + } + + #[test] + fn push_and_pop() { + let mut v: Vec = Vec::new(); + assert_eq!(v.len(), 0); + + assert_eq!(v.pop(), None); + assert_eq!(v.len(), 0); + + v.push(0).unwrap(); + assert_eq!(v.len(), 1); + + assert_eq!(v.pop(), Some(0)); + assert_eq!(v.len(), 0); + + assert_eq!(v.pop(), None); + assert_eq!(v.len(), 0); + } + + #[test] + fn resize_size_limit() { + let mut v: Vec = Vec::new(); + + v.resize(0, 0).unwrap(); + v.resize(4, 0).unwrap(); + v.resize(5, 0).expect_err("full"); + } + + #[test] + fn resize_length_cases() { + let mut v: Vec = Vec::new(); + + assert_eq!(v.len(), 0); + + // Grow by 1 + v.resize(1, 0).unwrap(); + assert_eq!(v.len(), 1); + + // Grow by 2 + v.resize(3, 0).unwrap(); + assert_eq!(v.len(), 3); + + // Resize to current size + v.resize(3, 0).unwrap(); + assert_eq!(v.len(), 3); + + // Shrink by 1 + v.resize(2, 0).unwrap(); + assert_eq!(v.len(), 2); + + // Shrink by 2 + v.resize(0, 0).unwrap(); + assert_eq!(v.len(), 0); + } + + #[test] + fn resize_contents() { + let mut v: Vec = Vec::new(); + + // New entries take supplied value when growing + v.resize(1, 17).unwrap(); + assert_eq!(v[0], 17); + + // Old values aren't changed when growing + v.resize(2, 18).unwrap(); + assert_eq!(v[0], 17); + assert_eq!(v[1], 18); + + // Old values aren't changed when length unchanged + v.resize(2, 0).unwrap(); + assert_eq!(v[0], 17); + assert_eq!(v[1], 18); + + // Old values aren't changed when shrinking + v.resize(1, 0).unwrap(); + assert_eq!(v[0], 17); + } + + #[test] + fn resize_default() { + let mut v: Vec = Vec::new(); + + // resize_default is implemented using resize, so just check the + // correct value is being written. + v.resize_default(1).unwrap(); + assert_eq!(v[0], 0); + } + + #[test] + fn write() { + let mut v: Vec = Vec::new(); + write!(v, "{:x}", 1234).unwrap(); + assert_eq!(&v[..], b"4d2"); + } + + #[test] + fn extend_from_slice() { + let mut v: Vec = Vec::new(); + assert_eq!(v.len(), 0); + v.extend_from_slice(&[1, 2]).unwrap(); + assert_eq!(v.len(), 2); + assert_eq!(v.as_slice(), &[1, 2]); + v.extend_from_slice(&[3]).unwrap(); + assert_eq!(v.len(), 3); + assert_eq!(v.as_slice(), &[1, 2, 3]); + assert!(v.extend_from_slice(&[4, 5]).is_err()); + assert_eq!(v.len(), 3); + assert_eq!(v.as_slice(), &[1, 2, 3]); + } + + #[test] + fn from_slice() { + // Successful construction + let v: Vec = Vec::from_slice(&[1, 2, 3]).unwrap(); + assert_eq!(v.len(), 3); + assert_eq!(v.as_slice(), &[1, 2, 3]); + + // Slice too large + assert!(Vec::::from_slice(&[1, 2, 3]).is_err()); + } + + #[test] + fn starts_with() { + let v: Vec<_, 8> = Vec::from_slice(b"ab").unwrap(); + assert!(v.starts_with(&[])); + assert!(v.starts_with(b"")); + assert!(v.starts_with(b"a")); + assert!(v.starts_with(b"ab")); + assert!(!v.starts_with(b"abc")); + assert!(!v.starts_with(b"ba")); + assert!(!v.starts_with(b"b")); + } + + #[test] + fn ends_with() { + let v: Vec<_, 8> = Vec::from_slice(b"ab").unwrap(); + assert!(v.ends_with(&[])); + assert!(v.ends_with(b"")); + assert!(v.ends_with(b"b")); + assert!(v.ends_with(b"ab")); + assert!(!v.ends_with(b"abc")); + assert!(!v.ends_with(b"ba")); + assert!(!v.ends_with(b"a")); + } + + #[test] + fn zero_capacity() { + let mut v: Vec = Vec::new(); + // Validate capacity + assert_eq!(v.capacity(), 0); + + // Make sure there is no capacity + assert!(v.push(1).is_err()); + + // Validate length + assert_eq!(v.len(), 0); + + // Validate pop + assert_eq!(v.pop(), None); + + // Validate slice + assert_eq!(v.as_slice(), &[]); + + // Validate empty + assert!(v.is_empty()); + + // Validate full + assert!(v.is_full()); + } +} diff --git a/rs-matter/src/utils/writebuf.rs b/rs-matter/src/utils/storage/writebuf.rs similarity index 99% rename from rs-matter/src/utils/writebuf.rs rename to rs-matter/src/utils/storage/writebuf.rs index f3363cc4..0dcb873e 100644 --- a/rs-matter/src/utils/writebuf.rs +++ b/rs-matter/src/utils/storage/writebuf.rs @@ -213,7 +213,7 @@ impl<'a> WriteBuf<'a> { #[cfg(test)] mod tests { - use crate::utils::writebuf::*; + use crate::utils::storage::WriteBuf; #[test] fn test_append_le_with_success() { diff --git a/rs-matter/src/utils/sync.rs b/rs-matter/src/utils/sync.rs new file mode 100644 index 00000000..d1174467 --- /dev/null +++ b/rs-matter/src/utils/sync.rs @@ -0,0 +1,26 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +pub use mutex::*; +pub use notification::*; +pub use signal::*; + +pub mod blocking; + +mod mutex; +mod notification; +mod signal; diff --git a/rs-matter/src/utils/sync/blocking.rs b/rs-matter/src/utils/sync/blocking.rs new file mode 100644 index 00000000..6d599dbd --- /dev/null +++ b/rs-matter/src/utils/sync/blocking.rs @@ -0,0 +1,253 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +pub use mutex::*; + +mod mutex { + //! A variation of the `embassy-sync` blocking mutex that allows in-place initialization + //! of the mutex with `Mutex::init(..) -> impl Init`. + //! Check `embassy_sync::blocking_mutex::Mutex` for the original implementation. + + #![allow(clippy::should_implement_trait)] + + use core::cell::UnsafeCell; + + use embassy_sync::blocking_mutex::raw::{self, RawMutex}; + + use crate::utils::init::{init, Init, UnsafeCellInit}; + + /// Blocking mutex (not async) + /// + /// Provides a blocking mutual exclusion primitive backed by an implementation of [`raw::RawMutex`]. + /// + /// Which implementation you select depends on the context in which you're using the mutex, and you can choose which kind + /// of interior mutability fits your use case. + /// + /// Use [`CriticalSectionMutex`] when data can be shared between threads and interrupts. + /// + /// Use [`NoopMutex`] when data is only shared between tasks running on the same executor. + /// + /// Use [`ThreadModeMutex`] when data is shared between tasks running on the same executor but you want a global singleton. + /// + /// In all cases, the blocking mutex is intended to be short lived and not held across await points. + /// Use the async [`Mutex`](crate::mutex::Mutex) if you need a lock that is held across await points. + pub struct Mutex { + // NOTE: `raw` must be FIRST, so when using ThreadModeMutex the "can't drop in non-thread-mode" gets + // to run BEFORE dropping `data`. + raw: R, + data: UnsafeCell, + } + + unsafe impl Send for Mutex {} + unsafe impl Sync for Mutex {} + + impl Mutex { + /// Creates a new mutex in an unlocked state ready for use. + #[inline] + pub const fn new(val: T) -> Mutex { + Mutex { + raw: R::INIT, + data: UnsafeCell::new(val), + } + } + + /// Creates a mutex in-place initializer in an unlocked state ready for use. + pub fn init>(val: I) -> impl Init { + init!(Self { + raw: R::INIT, + data <- UnsafeCell::init(val), + }) + } + + /// Creates a critical section and grants temporary access to the protected data. + pub fn lock(&self, f: impl FnOnce(&T) -> U) -> U { + self.raw.lock(|| { + let ptr = self.data.get() as *const T; + let inner = unsafe { &*ptr }; + f(inner) + }) + } + } + + impl Mutex { + /// Creates a new mutex based on a pre-existing raw mutex. + /// + /// This allows creating a mutex in a constant context on stable Rust. + #[inline] + pub const fn const_new(raw_mutex: R, val: T) -> Mutex { + Mutex { + raw: raw_mutex, + data: UnsafeCell::new(val), + } + } + + /// Consumes this mutex, returning the underlying data. + #[inline] + pub fn into_inner(self) -> T { + self.data.into_inner() + } + + /// Returns a mutable reference to the underlying data. + /// + /// Since this call borrows the `Mutex` mutably, no actual locking needs to + /// take place---the mutable borrow statically guarantees no locks exist. + #[inline] + pub fn get_mut(&mut self) -> &mut T { + unsafe { &mut *self.data.get() } + } + } + + /// A mutex that allows borrowing data across executors and interrupts. + /// + /// # Safety + /// + /// This mutex is safe to share between different executors and interrupts. + pub type CriticalSectionMutex = Mutex; + + /// A mutex that allows borrowing data in the context of a single executor. + /// + /// # Safety + /// + /// **This Mutex is only safe within a single executor.** + pub type NoopMutex = Mutex; + + impl Mutex { + /// Borrows the data for the duration of the critical section + pub fn borrow<'cs>(&'cs self, _cs: critical_section::CriticalSection<'cs>) -> &'cs T { + let ptr = self.data.get() as *const T; + unsafe { &*ptr } + } + } + + impl Mutex { + /// Borrows the data + pub fn borrow(&self) -> &T { + let ptr = self.data.get() as *const T; + unsafe { &*ptr } + } + } + + // // ThreadModeMutex does NOT use the generic mutex from above because it's special: + // // it's Send+Sync even if T: !Send. There's no way to do that without specialization (I think?). + // // + // // There's still a ThreadModeRawMutex for use with the generic Mutex (handy with Channel, for example), + // // but that will require T: Send even though it shouldn't be needed. + + // #[cfg(any(cortex_m, feature = "std"))] + // pub use thread_mode_mutex::*; + // #[cfg(any(cortex_m, feature = "std"))] + // mod thread_mode_mutex { + // use super::*; + + // /// A "mutex" that only allows borrowing from thread mode. + // /// + // /// # Safety + // /// + // /// **This Mutex is only safe on single-core systems.** + // /// + // /// On multi-core systems, a `ThreadModeMutex` **is not sufficient** to ensure exclusive access. + // pub struct ThreadModeMutex { + // inner: UnsafeCell, + // } + + // // NOTE: ThreadModeMutex only allows borrowing from one execution context ever: thread mode. + // // Therefore it cannot be used to send non-sendable stuff between execution contexts, so it can + // // be Send+Sync even if T is not Send (unlike CriticalSectionMutex) + // unsafe impl Sync for ThreadModeMutex {} + // unsafe impl Send for ThreadModeMutex {} + + // impl ThreadModeMutex { + // /// Creates a new mutex + // pub const fn new(value: T) -> Self { + // ThreadModeMutex { + // inner: UnsafeCell::new(value), + // } + // } + // } + + // impl ThreadModeMutex { + // /// Lock the `ThreadModeMutex`, granting access to the data. + // /// + // /// # Panics + // /// + // /// This will panic if not currently running in thread mode. + // pub fn lock(&self, f: impl FnOnce(&T) -> R) -> R { + // f(self.borrow()) + // } + + // /// Borrows the data + // /// + // /// # Panics + // /// + // /// This will panic if not currently running in thread mode. + // pub fn borrow(&self) -> &T { + // assert!( + // raw::in_thread_mode(), + // "ThreadModeMutex can only be borrowed from thread mode." + // ); + // unsafe { &*self.inner.get() } + // } + // } + + // impl Drop for ThreadModeMutex { + // fn drop(&mut self) { + // // Only allow dropping from thread mode. Dropping calls drop on the inner `T`, so + // // `drop` needs the same guarantees as `lock`. `ThreadModeMutex` is Send even if + // // T isn't, so without this check a user could create a ThreadModeMutex in thread mode, + // // send it to interrupt context and drop it there, which would "send" a T even if T is not Send. + // assert!( + // raw::in_thread_mode(), + // "ThreadModeMutex can only be dropped from thread mode." + // ); + + // // Drop of the inner `T` happens after this. + // } + // } + // } +} + +pub mod raw { + #[cfg(feature = "std")] + pub use std::*; + + #[cfg(feature = "std")] + mod std { + use embassy_sync::blocking_mutex::raw::RawMutex; + + /// An `embassy-sync` `RawMutex` implementation using `std::sync::Mutex`. + // TODO: Upstream into `embassy-sync` itself. + #[derive(Default)] + pub struct StdRawMutex(std::sync::Mutex<()>); + + impl StdRawMutex { + pub const fn new() -> Self { + Self(std::sync::Mutex::new(())) + } + } + + unsafe impl RawMutex for StdRawMutex { + #[allow(clippy::declare_interior_mutable_const)] + const INIT: Self = StdRawMutex(std::sync::Mutex::new(())); + + fn lock(&self, f: impl FnOnce() -> R) -> R { + let _guard = self.0.lock().unwrap(); + + f() + } + } + } +} diff --git a/rs-matter/src/utils/ifmutex.rs b/rs-matter/src/utils/sync/mutex.rs similarity index 94% rename from rs-matter/src/utils/ifmutex.rs rename to rs-matter/src/utils/sync/mutex.rs index e9dd8ef1..ab7590a0 100644 --- a/rs-matter/src/utils/ifmutex.rs +++ b/rs-matter/src/utils/sync/mutex.rs @@ -15,13 +15,17 @@ * limitations under the License. */ -//! A variation of the `embassy-sync` async mutex that only locks the mutex if a certain condition on the content of the data holds true. +//! A variation of the `embassy-sync` async mutex that only locks the mutex if a certain +//! condition on the content of the data holds true. //! Check `embassy_sync::Mutex` for the original unconditional implementation. + use core::cell::UnsafeCell; use core::ops::{Deref, DerefMut}; use embassy_sync::blocking_mutex::raw::RawMutex; +use crate::utils::init::{init, Init, UnsafeCellInit}; + use super::signal::Signal; /// Error returned by [`Mutex::try_lock`] @@ -55,6 +59,14 @@ where inner: UnsafeCell::new(value), } } + + /// Creates a mutex in-place initializer with the given value initializer. + pub fn init>(value: I) -> impl Init { + init!(Self { + state: Signal::::new(false), + inner <- UnsafeCell::init(value), + }) + } } impl IfMutex diff --git a/rs-matter/src/utils/notification.rs b/rs-matter/src/utils/sync/notification.rs similarity index 100% rename from rs-matter/src/utils/notification.rs rename to rs-matter/src/utils/sync/notification.rs diff --git a/rs-matter/src/utils/signal.rs b/rs-matter/src/utils/sync/signal.rs similarity index 74% rename from rs-matter/src/utils/signal.rs rename to rs-matter/src/utils/sync/signal.rs index a747f654..cc7fc1c8 100644 --- a/rs-matter/src/utils/signal.rs +++ b/rs-matter/src/utils/sync/signal.rs @@ -15,18 +15,38 @@ * limitations under the License. */ -use core::cell::RefCell; use core::future::poll_fn; use core::task::{Context, Poll}; -use embassy_sync::blocking_mutex::{raw::RawMutex, Mutex}; +use embassy_sync::blocking_mutex::raw::RawMutex; use embassy_sync::waitqueue::WakerRegistration; +use crate::utils::cell::RefCell; +use crate::utils::init::{init, Init}; + +use super::blocking::Mutex; + struct State { state: S, waker: WakerRegistration, } +impl State { + const fn new(state: S) -> Self { + Self { + state, + waker: WakerRegistration::new(), + } + } + + fn init>(state: I) -> impl Init { + init!(Self { + state <- state, + waker: WakerRegistration::new(), + }) + } +} + /// `Signal` is an async synchonization primitive that can be viewed as a generalization of the `embassy_sync::Signal` primitive /// that takes callback closures. /// @@ -38,18 +58,26 @@ struct State { /// The generic nature of `Signal` allows for a wide range of use cases, including the implementation of: /// - the `Notification` primitive /// - the `IfMutex` primitive -pub struct Signal(Mutex>>); +pub struct Signal { + inner: Mutex>>, +} impl Signal where M: RawMutex, { - /// Crate a `Signal` with the given initial state `S`. + /// Create a `Signal` with the given initial state `S`. pub const fn new(state: S) -> Self { - Self(Mutex::new(RefCell::new(State { - state, - waker: WakerRegistration::new(), - }))) + Self { + inner: Mutex::new(RefCell::new(State::new(state))), + } + } + + /// Create a `Signal` in-place initializer with the given initial state initializer `I`. + pub fn init>(state: I) -> impl Init { + init!(Self { + inner <- Mutex::init(RefCell::init(State::init(state))), + }) } // Modify the state `S` and wake up the waiters if necessary. @@ -57,7 +85,7 @@ where where F: FnOnce(&mut S) -> (bool, R), { - self.0.lock(|s| { + self.inner.lock(|s| { let mut s = s.borrow_mut(); let (wake, result) = f(&mut s.state); @@ -83,7 +111,7 @@ where where F: FnOnce(&mut S) -> Option, { - self.0.lock(|s| { + self.inner.lock(|s| { let mut s = s.borrow_mut(); if let Some(result) = f(&mut s.state) { diff --git a/rs-matter/tests/common/attributes.rs b/rs-matter/tests/common/attributes.rs index eabef20c..55a5841b 100644 --- a/rs-matter/tests/common/attributes.rs +++ b/rs-matter/tests/common/attributes.rs @@ -18,7 +18,7 @@ use rs_matter::{ interaction_model::{messages::ib::AttrResp, messages::msg::ReportDataMsg}, tlv::{TLVElement, TLVList, TLVWriter, TagType, ToTLV}, - utils::writebuf::WriteBuf, + utils::storage::WriteBuf, }; /// Assert that the data received in the outbuf matches our expectations diff --git a/rs-matter/tests/common/im_engine.rs b/rs-matter/tests/common/im_engine.rs index 3a0b3b5d..50fd67ca 100644 --- a/rs-matter/tests/common/im_engine.rs +++ b/rs-matter/tests/common/im_engine.rs @@ -65,7 +65,7 @@ use rs_matter::{ }, session::{NocCatIds, ReservedSession, SessionMode}, }, - utils::{buf::PooledBuffers, select::Coalesce}, + utils::{select::Coalesce, storage::pooled::PooledBuffers}, Matter, MATTER_PORT, }; @@ -295,7 +295,7 @@ impl<'a> ImEngine<'a> { ADDR, SessionMode::Case { fab_idx: NonZeroU8::new(1).unwrap(), - cat_ids: cat_ids.clone(), + cat_ids: *cat_ids, }, None, None, diff --git a/rs-matter/tests/tlv_encoding.rs b/rs-matter/tests/tlv_encoding.rs index 5ce9664c..99a0b003 100644 --- a/rs-matter/tests/tlv_encoding.rs +++ b/rs-matter/tests/tlv_encoding.rs @@ -20,7 +20,7 @@ mod tlv_encoding_tests { use rs_matter::bitflags_tlv; use rs_matter::error::Error; use rs_matter::tlv::{get_root_node, FromTLV, TLVElement, TLVWriter, TagType, ToTLV}; - use rs_matter::utils::writebuf::WriteBuf; + use rs_matter::utils::storage::WriteBuf; #[derive(PartialEq, Debug, ToTLV, FromTLV)] struct SimpleStruct {