From 80f4ff925568a30f59ed13a832c7ee3ae73130fc Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Sat, 6 Jul 2024 07:31:30 +0000 Subject: [PATCH] In-place initialization Update the example; more ergonomic init of UnsafeCell and MaybeUninit More constructors In-place ctr for the built-in Mdns too enable no_std for pinned-init In-place init for BTP Re-export pinned-init as its API is unstable; document Initializer for GATT peripheral PAtch pinned-init with a forked version that does not need nightly zeroed in pinned-init master no longer has the E generic Use upstream pinned-init In-place initializer for State Stop pretending that ContainerInit has any other use case besides UnsafeCell Maybe and other extensions (Code review feedbacmoveplit storage and sync types into their dedicated modules Fix astro and zeroconf Restore a larger stack for now, until in-place-init is fully utilized Remove commented out code --- examples/onoff_light/src/main.rs | 62 +- examples/onoff_light_bt/src/comm.rs | 2 +- examples/onoff_light_bt/src/main.rs | 9 +- rs-matter/Cargo.toml | 6 +- rs-matter/src/acl.rs | 72 +- rs-matter/src/cert/mod.rs | 4 +- rs-matter/src/core.rs | 121 +- rs-matter/src/data_model/core.rs | 4 +- rs-matter/src/data_model/sdm/noc.rs | 17 +- rs-matter/src/data_model/subscriptions.rs | 18 +- .../data_model/system_model/access_control.rs | 11 +- rs-matter/src/fabric.rs | 36 +- rs-matter/src/interaction_model/core.rs | 2 +- rs-matter/src/mdns.rs | 72 +- rs-matter/src/mdns/astro.rs | 18 +- rs-matter/src/mdns/builtin.rs | 31 +- rs-matter/src/mdns/proto.rs | 2 +- rs-matter/src/mdns/zeroconf.rs | 18 +- rs-matter/src/pairing/qr.rs | 2 +- rs-matter/src/persist.rs | 88 +- rs-matter/src/respond.rs | 2 +- rs-matter/src/secure_channel/case.rs | 2 +- rs-matter/src/secure_channel/common.rs | 2 +- rs-matter/src/secure_channel/pake.rs | 43 +- rs-matter/src/secure_channel/status_report.rs | 2 +- rs-matter/src/tlv/traits.rs | 34 +- rs-matter/src/tlv/writer.rs | 9 +- rs-matter/src/transport/core.rs | 107 +- rs-matter/src/transport/exchange.rs | 3 +- rs-matter/src/transport/network/btp.rs | 28 +- .../src/transport/network/btp/context.rs | 32 +- .../src/transport/network/btp/gatt/bluer.rs | 45 +- .../src/transport/network/btp/session.rs | 30 +- .../transport/network/btp/session/packet.rs | 2 +- rs-matter/src/transport/network/btp/test.rs | 3 +- rs-matter/src/transport/packet.rs | 8 +- rs-matter/src/transport/plain_hdr.rs | 3 +- rs-matter/src/transport/proto_hdr.rs | 3 +- rs-matter/src/transport/session.rs | 23 +- rs-matter/src/utils/cell.rs | 1094 +++++++++++ rs-matter/src/utils/epoch.rs | 17 + rs-matter/src/utils/init.rs | 91 + rs-matter/src/utils/maybe.rs | 215 +++ rs-matter/src/utils/mod.rs | 13 +- .../src/utils/{std_mutex.rs => storage.rs} | 32 +- rs-matter/src/utils/{ => storage}/parsebuf.rs | 2 +- .../src/utils/{buf.rs => storage/pooled.rs} | 23 +- rs-matter/src/utils/{ => storage}/ringbuf.rs | 40 +- rs-matter/src/utils/storage/vec.rs | 1691 +++++++++++++++++ rs-matter/src/utils/{ => storage}/writebuf.rs | 2 +- rs-matter/src/utils/sync.rs | 26 + rs-matter/src/utils/sync/blocking.rs | 253 +++ .../src/utils/{ifmutex.rs => sync/mutex.rs} | 14 +- .../src/utils/{ => sync}/notification.rs | 0 rs-matter/src/utils/{ => sync}/signal.rs | 48 +- rs-matter/tests/common/attributes.rs | 2 +- rs-matter/tests/common/im_engine.rs | 4 +- rs-matter/tests/tlv_encoding.rs | 2 +- 58 files changed, 4185 insertions(+), 360 deletions(-) create mode 100644 rs-matter/src/utils/cell.rs create mode 100644 rs-matter/src/utils/init.rs create mode 100644 rs-matter/src/utils/maybe.rs rename rs-matter/src/utils/{std_mutex.rs => storage.rs} (50%) rename rs-matter/src/utils/{ => storage}/parsebuf.rs (99%) rename rs-matter/src/utils/{buf.rs => storage/pooled.rs} (85%) rename rs-matter/src/utils/{ => storage}/ringbuf.rs (87%) create mode 100644 rs-matter/src/utils/storage/vec.rs rename rs-matter/src/utils/{ => storage}/writebuf.rs (99%) create mode 100644 rs-matter/src/utils/sync.rs create mode 100644 rs-matter/src/utils/sync/blocking.rs rename rs-matter/src/utils/{ifmutex.rs => sync/mutex.rs} (94%) rename rs-matter/src/utils/{ => sync}/notification.rs (100%) rename rs-matter/src/utils/{ => sync}/signal.rs (74%) diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index 1c3d25c2..cc67a468 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -39,12 +39,36 @@ use rs_matter::persist::Psm; use rs_matter::respond::DefaultResponder; use rs_matter::secure_channel::spake2p::VerifierData; use rs_matter::transport::core::MATTER_SOCKET_BIND_ADDR; -use rs_matter::utils::buf::PooledBuffers; +use rs_matter::utils::init::InitMaybeUninit; use rs_matter::utils::select::Coalesce; +use rs_matter::utils::storage::pooled::PooledBuffers; use rs_matter::MATTER_PORT; +use static_cell::StaticCell; mod dev_att; +static DEV_DET: BasicInfoConfig = BasicInfoConfig { + vid: 0xFFF1, + pid: 0x8000, + hw_ver: 2, + sw_ver: 1, + sw_ver_str: "1", + serial_no: "aabbccdd", + device_name: "OnOff Light", + product_name: "Light123", + vendor_name: "Vendor PQR", +}; + +static DEV_ATT: dev_att::HardCodedDevAtt = dev_att::HardCodedDevAtt::new(); + +static MATTER: StaticCell = StaticCell::new(); + +static BUFFERS: StaticCell> = StaticCell::new(); + +static SUBSCRIPTIONS: StaticCell> = StaticCell::new(); + +static PSM: StaticCell> = StaticCell::new(); + fn main() -> Result<(), Error> { let thread = std::thread::Builder::new() // Increase the stack size until the example can work without stack blowups. @@ -54,7 +78,7 @@ fn main() -> Result<(), Error> { // e.g., an opt-level of "0" will require a several times' larger stack. // // Optimizing/lowering `rs-matter` memory consumption is an ongoing topic. - .stack_size(95 * 1024) + .stack_size(65 * 1024) .spawn(run) .unwrap(); @@ -72,36 +96,22 @@ fn run() -> Result<(), Error> { core::mem::size_of::>() ); - let dev_det = BasicInfoConfig { - vid: 0xFFF1, - pid: 0x8000, - hw_ver: 2, - sw_ver: 1, - sw_ver_str: "1", - serial_no: "aabbccdd", - device_name: "OnOff Light", - product_name: "Light123", - vendor_name: "Vendor PQR", - }; - - let dev_att = dev_att::HardCodedDevAtt::new(); - - let matter = Matter::new( - &dev_det, - &dev_att, + let matter = MATTER.uninit().init_with(Matter::init( + &DEV_DET, + &DEV_ATT, // NOTE: // For `no_std` environments, provide your own epoch and rand functions here MdnsService::Builtin, rs_matter::utils::epoch::sys_epoch, rs_matter::utils::rand::sys_rand, MATTER_PORT, - ); + )); matter.initialize_transport_buffers()?; info!("Matter initialized"); - let buffers = PooledBuffers::<10, NoopRawMutex, _>::new(0); + let buffers = BUFFERS.uninit().init_with(PooledBuffers::init(0)); info!("IM buffers initialized"); @@ -109,14 +119,14 @@ fn run() -> Result<(), Error> { let on_off = cluster_on_off::OnOffCluster::new(Dataver::new_rand(matter.rand())); - let subscriptions = Subscriptions::<3>::new(); + let subscriptions = SUBSCRIPTIONS.uninit().init_with(Subscriptions::init()); // Assemble our Data Model handler by composing the predefined Root Endpoint handler with our custom On/Off clusters let dm_handler = HandlerCompat(dm_handler(&matter, &on_off)); // Create a default responder capable of handling up to 3 subscriptions // All other subscription requests will be turned down with "resource exhausted" - let responder = DefaultResponder::new(&matter, &buffers, &subscriptions, dm_handler); + let responder = DefaultResponder::new(&matter, buffers, &subscriptions, dm_handler); info!( "Responder memory: Responder={}B, Runner={}B", core::mem::size_of_val(&responder), @@ -161,8 +171,10 @@ fn run() -> Result<(), Error> { // NOTE: // Replace with your own persister for e.g. `no_std` environments - let mut psm = Psm::new(&matter, std::env::temp_dir().join("rs-matter"))?; - let mut persist = pin!(psm.run()); + + let psm = PSM.uninit().init_with(Psm::init()); + + let mut persist = pin!(psm.run(std::env::temp_dir().join("rs-matter"), &matter)); // Combine all async tasks in a single one let all = select4( diff --git a/examples/onoff_light_bt/src/comm.rs b/examples/onoff_light_bt/src/comm.rs index 6052b485..3c1f2135 100644 --- a/examples/onoff_light_bt/src/comm.rs +++ b/examples/onoff_light_bt/src/comm.rs @@ -32,7 +32,7 @@ use rs_matter::interaction_model::core::IMStatusCode; use rs_matter::interaction_model::messages::ib::Status; use rs_matter::tlv::{FromTLV, OctetStr, TLVElement}; use rs_matter::transport::exchange::Exchange; -use rs_matter::utils::notification::Notification; +use rs_matter::utils::sync::Notification; /// A _fake_ cluster implementing the Matter Network Commissioning Cluster /// for managing WiFi networks. diff --git a/examples/onoff_light_bt/src/main.rs b/examples/onoff_light_bt/src/main.rs index 4d6725ab..8099eb28 100644 --- a/examples/onoff_light_bt/src/main.rs +++ b/examples/onoff_light_bt/src/main.rs @@ -62,10 +62,9 @@ use rs_matter::respond::DefaultResponder; use rs_matter::secure_channel::spake2p::VerifierData; use rs_matter::transport::core::MATTER_SOCKET_BIND_ADDR; use rs_matter::transport::network::btp::{Btp, BtpContext}; -use rs_matter::utils::buf::PooledBuffers; -use rs_matter::utils::notification::Notification; use rs_matter::utils::select::Coalesce; -use rs_matter::utils::std_mutex::StdRawMutex; +use rs_matter::utils::storage::pooled::PooledBuffers; +use rs_matter::utils::sync::{blocking::raw::StdRawMutex, Notification}; use rs_matter::MATTER_PORT; mod comm; @@ -182,8 +181,8 @@ fn run() -> Result<(), Error> { // NOTE: // Replace with your own persister for e.g. `no_std` environments - let mut psm = Psm::new(&matter, std::env::temp_dir().join("rs-matter"))?; - let mut persist = pin!(psm.run()); + let mut psm: Psm<4096> = Psm::new(); + let mut persist = pin!(psm.run(std::env::temp_dir().join("rs-matter"), &matter)); if !matter.is_commissioned() { // Not commissioned yet, start commissioning first diff --git a/rs-matter/Cargo.toml b/rs-matter/Cargo.toml index 5f46190c..61548f9d 100644 --- a/rs-matter/Cargo.toml +++ b/rs-matter/Cargo.toml @@ -44,7 +44,8 @@ critical-section = "1.1" domain = { version = "0.10", default-features = false, features = ["heapless"] } portable-atomic = "1" qrcodegen-no-heap = "1.8" -scopeguard = "1" +scopeguard = { version = "1", default-features = false } +pinned-init = { version = "0.0.8", default-features = false } # crypto openssl = { version = "0.10", optional = true } @@ -81,8 +82,9 @@ tokio-stream = { version = "0.1" } [dev-dependencies] env_logger = "0.11" nix = { version = "0.27", features = ["net"] } -futures-lite = "1" +futures-lite = "2" async-channel = "2" +static_cell = "2" [[example]] name = "onoff_light" diff --git a/rs-matter/src/acl.rs b/rs-matter/src/acl.rs index b5715f65..5289b157 100644 --- a/rs-matter/src/acl.rs +++ b/rs-matter/src/acl.rs @@ -15,20 +15,22 @@ * limitations under the License. */ -use core::{cell::RefCell, fmt::Display, num::NonZeroU8}; - -use crate::{ - data_model::objects::{Access, ClusterId, EndptId, Privilege}, - error::{Error, ErrorCode}, - fabric, - interaction_model::messages::GenericPath, - tlv::{self, FromTLV, Nullable, TLVElement, TLVList, TLVWriter, TagType, ToTLV}, - transport::session::{Session, SessionMode, MAX_CAT_IDS_PER_NOC}, - utils::writebuf::WriteBuf, -}; +use core::{fmt::Display, num::NonZeroU8}; + use log::error; + use num_derive::FromPrimitive; +use crate::data_model::objects::{Access, ClusterId, EndptId, Privilege}; +use crate::error::{Error, ErrorCode}; +use crate::fabric; +use crate::interaction_model::messages::GenericPath; +use crate::tlv::{self, FromTLV, Nullable, TLVElement, TLVList, TLVWriter, TagType, ToTLV}; +use crate::transport::session::{Session, SessionMode, MAX_CAT_IDS_PER_NOC}; +use crate::utils::cell::RefCell; +use crate::utils::init::{init, Init}; +use crate::utils::storage::WriteBuf; + // Matter Minimum Requirements pub const SUBJECTS_PER_ENTRY: usize = 4; pub const TARGETS_PER_ENTRY: usize = 3; @@ -301,15 +303,17 @@ pub struct AclEntry { subjects: Subjects, targets: Targets, // TODO: Instead of the direct value, we should consider GlobalElements::FabricIndex + // Note that this field will always be `Some(NN)` when the entry is persisted in storage, + // however, it will be `None` when the entry is coming from the other peer #[tagval(0xFE)] - pub fab_idx: NonZeroU8, + pub fab_idx: Option, } impl AclEntry { pub fn new(fab_idx: NonZeroU8, privilege: Privilege, auth_mode: AuthMode) -> Self { const INIT_SUBJECTS: Option = None; Self { - fab_idx, + fab_idx: Some(fab_idx), privilege, auth_mode, subjects: [INIT_SUBJECTS; SUBJECTS_PER_ENTRY], @@ -368,7 +372,11 @@ impl AclEntry { } // true if both are true - allow && self.fab_idx.get() == accessor.fab_idx + allow + && self + .fab_idx + .map(|fab_idx| fab_idx.get() == accessor.fab_idx) + .unwrap_or(false) } fn match_access_desc(&self, object: &AccessDesc) -> bool { @@ -411,7 +419,7 @@ impl AclEntry { const MAX_ACL_ENTRIES: usize = ENTRIES_PER_FABRIC * fabric::MAX_SUPPORTED_FABRICS; -type AclEntries = heapless::Vec, MAX_ACL_ENTRIES>; +type AclEntries = crate::utils::storage::Vec, MAX_ACL_ENTRIES>; pub struct AclMgr { entries: AclEntries, @@ -425,6 +433,7 @@ impl Default for AclMgr { } impl AclMgr { + /// Create a new ACL Manager #[inline(always)] pub const fn new() -> Self { Self { @@ -433,6 +442,14 @@ impl AclMgr { } } + /// Return an in-place initializer for ACL Manager + pub fn init() -> impl Init { + init!(Self { + entries <- AclEntries::init(), + changed: false, + }) + } + pub fn erase_all(&mut self) -> Result<(), Error> { self.entries.clear(); self.changed = true; @@ -441,13 +458,18 @@ impl AclMgr { } pub fn add(&mut self, entry: AclEntry) -> Result { + let Some(fab_idx) = entry.fab_idx else { + // When persisting entries, the `fab_idx` should always be set + return Err(ErrorCode::Invalid.into()); + }; + if entry.auth_mode == AuthMode::Pase { // Reserved for future use // TODO: Should be something that results in IMStatusCode::ConstraintError Err(ErrorCode::Invalid)?; } - let cnt = self.get_index_in_fabric(MAX_ACL_ENTRIES, entry.fab_idx); + let cnt = self.get_index_in_fabric(MAX_ACL_ENTRIES, fab_idx); if cnt >= ENTRIES_PER_FABRIC as u8 { Err(ErrorCode::NoSpace)?; } @@ -455,8 +477,6 @@ impl AclMgr { let slot = self.entries.iter().position(|a| a.is_none()); if slot.is_some() || self.entries.len() < MAX_ACL_ENTRIES { - let fab_idx = entry.fab_idx; - let slot = if let Some(slot) = slot { self.entries[slot] = Some(entry); @@ -501,7 +521,7 @@ impl AclMgr { for entry in &mut self.entries { if entry .as_ref() - .map(|e| e.fab_idx == fab_idx) + .map(|e| e.fab_idx == Some(fab_idx)) .unwrap_or(false) { *entry = None; @@ -565,7 +585,7 @@ impl AclMgr { pub fn load(&mut self, data: &[u8]) -> Result<(), Error> { let root = TLVList::new(data).iter().next().ok_or(ErrorCode::Invalid)?; - tlv::from_tlv(&mut self.entries, &root)?; + tlv::vec_from_tlv(&mut self.entries, &root)?; self.changed = false; Ok(()) @@ -606,7 +626,11 @@ impl AclMgr { for (curr_index, entry) in self .entries .iter_mut() - .filter(|e| e.as_ref().filter(|e1| e1.fab_idx == fab_idx).is_some()) + .filter(|e| { + e.as_ref() + .filter(|e1| e1.fab_idx == Some(fab_idx)) + .is_some() + }) .enumerate() { if curr_index == index as usize { @@ -625,7 +649,7 @@ impl AclMgr { .iter() .take(till_slot_index) .flatten() - .filter(|e| e.fab_idx == fab_idx) + .filter(|e| e.fab_idx == Some(fab_idx)) .count() as u8 } } @@ -643,7 +667,7 @@ impl core::fmt::Display for AclMgr { #[cfg(test)] #[allow(clippy::bool_assert_comparison)] pub(crate) mod tests { - use core::{cell::RefCell, num::NonZeroU8}; + use core::num::NonZeroU8; use crate::{ acl::{gen_noc_cat, AccessorSubjects}, @@ -651,6 +675,8 @@ pub(crate) mod tests { interaction_model::messages::GenericPath, }; + use crate::utils::cell::RefCell; + use super::{AccessReq, Accessor, AclEntry, AclMgr, AuthMode, Target}; pub(crate) const FAB_1: NonZeroU8 = match NonZeroU8::new(1) { diff --git a/rs-matter/src/cert/mod.rs b/rs-matter/src/cert/mod.rs index c5c158a5..0bbf00fe 100644 --- a/rs-matter/src/cert/mod.rs +++ b/rs-matter/src/cert/mod.rs @@ -21,7 +21,7 @@ use crate::{ crypto::KeyPair, error::{Error, ErrorCode}, tlv::{self, FromTLV, OctetStr, TLVArray, TLVElement, TLVWriter, TagType, ToTLV}, - utils::{epoch::MATTER_CERT_DOESNT_EXPIRE, writebuf::WriteBuf}, + utils::{epoch::MATTER_CERT_DOESNT_EXPIRE, storage::WriteBuf}, }; use log::error; use num_derive::FromPrimitive; @@ -857,7 +857,7 @@ mod tests { use crate::cert::Cert; use crate::tlv::{self, FromTLV, TLVWriter, TagType, ToTLV}; - use crate::utils::writebuf::WriteBuf; + use crate::utils::storage::WriteBuf; #[test] fn test_asn1_encode_success() { diff --git a/rs-matter/src/core.rs b/rs-matter/src/core.rs index a52c18ac..c68db6ce 100644 --- a/rs-matter/src/core.rs +++ b/rs-matter/src/core.rs @@ -15,27 +15,26 @@ * limitations under the License. */ -use core::cell::RefCell; - use embassy_sync::blocking_mutex::raw::NoopRawMutex; -use crate::{ - acl::AclMgr, - data_model::{ - cluster_basic_information::BasicInfoConfig, - sdm::{dev_att::DevAttDataFetcher, failsafe::FailSafe}, - }, - error::*, - fabric::FabricMgr, - mdns::MdnsService, - pairing::{print_pairing_code_and_qr, DiscoveryCapabilities}, - secure_channel::{pake::PaseMgr, spake2p::VerifierData}, - transport::{ - core::{PacketBufferExternalAccess, TransportMgr}, - network::{NetworkReceive, NetworkSend}, - }, - utils::{buf::BufferAccess, epoch::Epoch, notification::Notification, rand::Rand}, +use crate::acl::AclMgr; +use crate::data_model::{ + cluster_basic_information::BasicInfoConfig, + sdm::{dev_att::DevAttDataFetcher, failsafe::FailSafe}, }; +use crate::error::*; +use crate::fabric::FabricMgr; +use crate::mdns::MdnsService; +use crate::pairing::{print_pairing_code_and_qr, DiscoveryCapabilities}; +use crate::secure_channel::{pake::PaseMgr, spake2p::VerifierData}; +use crate::transport::core::{PacketBufferExternalAccess, TransportMgr}; +use crate::transport::network::{NetworkReceive, NetworkSend}; +use crate::utils::cell::RefCell; +use crate::utils::epoch::Epoch; +use crate::utils::init::{init, Init}; +use crate::utils::rand::Rand; +use crate::utils::storage::pooled::BufferAccess; +use crate::utils::sync::Notification; /* The Matter Port */ pub const MATTER_PORT: u16 = 5540; @@ -65,6 +64,16 @@ pub struct Matter<'a> { } impl<'a> Matter<'a> { + /// Create a new Matter object when support for the Rust Standard Library is enabled. + /// + /// # Parameters + /// * dev_det: An object of type [BasicInfoConfig]. + /// * dev_att: An object that implements the trait [DevAttDataFetcher]. Any Matter device + /// requires a set of device attestation certificates and keys. It is the responsibility of + /// this object to return the device attestation details when queried upon. + /// * mdns: An object of type [MdnsService]. This object is responsible for handling mDNS + /// responses and queries related to the operation of the Matter stack. + /// * port: The port number on which the Matter stack will listen for incoming connections. #[cfg(feature = "std")] #[inline(always)] pub const fn new_default( @@ -79,12 +88,19 @@ impl<'a> Matter<'a> { Self::new(dev_det, dev_att, mdns, sys_epoch, sys_rand, port) } - /// Creates a new Matter object + /// Create a new Matter object /// /// # Parameters + /// * dev_det: An object of type [BasicInfoConfig]. /// * dev_att: An object that implements the trait [DevAttDataFetcher]. Any Matter device /// requires a set of device attestation certificates and keys. It is the responsibility of /// this object to return the device attestation details when queried upon. + /// * mdns: An object of type [MdnsService]. This object is responsible for handling mDNS + /// responses and queries related to the operation of the Matter stack. + /// * epoch: A function of type [Epoch]. This function is responsible for providing the current + /// "unix" time in milliseconds + /// * rand: A function of type [Rand]. This function is responsible for generating random data. + /// * port: The port number on which the Matter stack will listen for incoming connections. #[inline(always)] pub const fn new( dev_det: &'a BasicInfoConfig<'a>, @@ -99,7 +115,7 @@ impl<'a> Matter<'a> { acl_mgr: RefCell::new(AclMgr::new()), pase_mgr: RefCell::new(PaseMgr::new(epoch, rand)), failsafe: RefCell::new(FailSafe::new()), - transport_mgr: TransportMgr::new(mdns.new_impl(dev_det, port), epoch, rand), + transport_mgr: TransportMgr::new(mdns, dev_det, port, epoch, rand), persist_notification: Notification::new(), epoch, rand, @@ -109,6 +125,68 @@ impl<'a> Matter<'a> { } } + /// Create an in-place initializer for a Matter object + /// when support for the Rust Standard Library is enabled. + /// + /// # Parameters + /// * dev_det: An object of type [BasicInfoConfig]. + /// * dev_att: An object that implements the trait [DevAttDataFetcher]. Any Matter device + /// requires a set of device attestation certificates and keys. It is the responsibility of + /// this object to return the device attestation details when queried upon. + /// * mdns: An object of type [MdnsService]. This object is responsible for handling mDNS + /// responses and queries related to the operation of the Matter stack. + /// * port: The port number on which the Matter stack will listen for incoming connections. + #[cfg(feature = "std")] + pub fn init_default( + dev_det: &'a BasicInfoConfig<'a>, + dev_att: &'a dyn DevAttDataFetcher, + mdns: MdnsService<'a>, + port: u16, + ) -> impl Init { + use crate::utils::epoch::sys_epoch; + use crate::utils::rand::sys_rand; + + Self::init(dev_det, dev_att, mdns, sys_epoch, sys_rand, port) + } + + /// Create an in-place initializer for a Matter object + /// + /// # Parameters + /// * dev_det: An object of type [BasicInfoConfig]. + /// * dev_att: An object that implements the trait [DevAttDataFetcher]. Any Matter device + /// requires a set of device attestation certificates and keys. It is the responsibility of + /// this object to return the device attestation details when queried upon. + /// * mdns: An object of type [MdnsService]. This object is responsible for handling mDNS + /// responses and queries related to the operation of the Matter stack. + /// * epoch: A function of type [Epoch]. This function is responsible for providing the current + /// "unix" time in milliseconds + /// * rand: A function of type [Rand]. This function is responsible for generating random data. + /// * port: The port number on which the Matter stack will listen for incoming connections. + pub fn init( + dev_det: &'a BasicInfoConfig<'a>, + dev_att: &'a dyn DevAttDataFetcher, + mdns: MdnsService<'a>, + epoch: Epoch, + rand: Rand, + port: u16, + ) -> impl Init { + init!( + Self { + fabric_mgr <- RefCell::init(FabricMgr::init()), + acl_mgr <- RefCell::init(AclMgr::init()), + pase_mgr <- RefCell::init(PaseMgr::init(epoch, rand)), + failsafe: RefCell::new(FailSafe::new()), + transport_mgr <- TransportMgr::init(mdns, dev_det, port, epoch, rand), + persist_notification: Notification::new(), + epoch, + rand, + dev_det, + dev_att, + port, + } + ) + } + pub fn initialize_transport_buffers(&self) -> Result<(), Error> { self.transport_mgr.initialize_buffers() } @@ -155,8 +233,7 @@ impl<'a> Matter<'a> { /// after that - while/if we still have exclusive, mutable access to the `Matter` object - /// replace the `MdnsService::Disabled` initial impl with another, like `MdnsService::Provided`. pub fn replace_mdns(&mut self, mdns: MdnsService<'a>) { - self.transport_mgr - .replace_mdns(mdns.new_impl(self.dev_det, self.port)); + self.transport_mgr.replace_mdns(mdns); } /// A utility method to replace the initial Device Attestation Data Fetcher with another one. diff --git a/rs-matter/src/data_model/core.rs b/rs-matter/src/data_model/core.rs index a3ebeb88..e17115f3 100644 --- a/rs-matter/src/data_model/core.rs +++ b/rs-matter/src/data_model/core.rs @@ -26,7 +26,7 @@ use embassy_time::{Instant, Timer}; use log::{debug, error, info, warn}; use crate::interaction_model::messages::ib::AttrStatus; -use crate::utils::buf::BufferAccess; +use crate::utils::storage::pooled::BufferAccess; use crate::{error::*, Matter}; use crate::interaction_model::core::{ @@ -39,7 +39,7 @@ use crate::interaction_model::messages::msg::{ use crate::respond::ExchangeHandler; use crate::tlv::{get_root_node_struct, FromTLV, TLVWriter, TagType}; use crate::transport::exchange::{Exchange, MAX_EXCHANGE_RX_BUF_SIZE, MAX_EXCHANGE_TX_BUF_SIZE}; -use crate::utils::writebuf::WriteBuf; +use crate::utils::storage::WriteBuf; use super::objects::*; use super::subscriptions::Subscriptions; diff --git a/rs-matter/src/data_model/sdm/noc.rs b/rs-matter/src/data_model/sdm/noc.rs index 64e44f30..52793766 100644 --- a/rs-matter/src/data_model/sdm/noc.rs +++ b/rs-matter/src/data_model/sdm/noc.rs @@ -32,7 +32,7 @@ use crate::tlv::{FromTLV, OctetStr, TLVElement, TLVWriter, TagType, ToTLV, UtfSt use crate::transport::exchange::Exchange; use crate::transport::session::SessionMode; use crate::utils::epoch::Epoch; -use crate::utils::writebuf::WriteBuf; +use crate::utils::storage::WriteBuf; use crate::{attribute_enum, cmd_enter, command_enum, error::*}; use super::dev_att::{DataType, DevAttDataFetcher}; @@ -157,14 +157,14 @@ pub const CLUSTER: Cluster<'static> = Cluster { pub struct NocData { pub key_pair: KeyPair, - pub root_ca: heapless::Vec, + pub root_ca: crate::utils::storage::Vec, } impl NocData { pub fn new(key_pair: KeyPair) -> Self { Self { key_pair, - root_ca: heapless::Vec::new(), + root_ca: crate::utils::storage::Vec::new(), } } } @@ -327,15 +327,16 @@ impl NocCluster { let noc_cert = Cert::new(r.noc_value.0).map_err(|_| NocStatus::InvalidNOC)?; info!("Received NOC as: {}", noc_cert); - let noc = heapless::Vec::from_slice(r.noc_value.0).map_err(|_| NocStatus::InvalidNOC)?; + let noc = crate::utils::storage::Vec::from_slice(r.noc_value.0) + .map_err(|_| NocStatus::InvalidNOC)?; let icac = if let Some(icac_value) = r.icac_value { if !icac_value.0.is_empty() { let icac_cert = Cert::new(icac_value.0).map_err(|_| NocStatus::InvalidNOC)?; info!("Received ICAC as: {}", icac_cert); - let icac = - heapless::Vec::from_slice(icac_value.0).map_err(|_| NocStatus::InvalidNOC)?; + let icac = crate::utils::storage::Vec::from_slice(icac_value.0) + .map_err(|_| NocStatus::InvalidNOC)?; Some(icac) } else { None @@ -661,8 +662,8 @@ impl NocCluster { let req = CommonReq::from_tlv(data).map_err(Error::map_invalid_command)?; info!("Received Trusted Cert:{:x?}", req.str); - noc_data.root_ca = - heapless::Vec::from_slice(req.str.0).map_err(|_| ErrorCode::BufferTooSmall)?; + noc_data.root_ca = crate::utils::storage::Vec::from_slice(req.str.0) + .map_err(|_| ErrorCode::BufferTooSmall)?; Ok(()) }) diff --git a/rs-matter/src/data_model/subscriptions.rs b/rs-matter/src/data_model/subscriptions.rs index e207c511..b6220d7f 100644 --- a/rs-matter/src/data_model/subscriptions.rs +++ b/rs-matter/src/data_model/subscriptions.rs @@ -15,7 +15,6 @@ * limitations under the License. */ -use core::cell::RefCell; use core::num::NonZeroU8; use embassy_sync::blocking_mutex::raw::NoopRawMutex; @@ -23,7 +22,9 @@ use embassy_time::Instant; use portable_atomic::{AtomicU32, Ordering}; -use crate::utils::notification::Notification; +use crate::utils::cell::RefCell; +use crate::utils::init::{init, Init}; +use crate::utils::sync::Notification; struct Subscription { fabric_idx: NonZeroU8, @@ -62,7 +63,7 @@ impl Subscription { /// Additional subscriptions are rejected by the data model with a "resource exhausted" IM status message. pub struct Subscriptions { next_subscription_id: AtomicU32, - subscriptions: RefCell>, + subscriptions: RefCell>, pub(crate) notification: Notification, } @@ -78,11 +79,20 @@ impl Subscriptions { pub const fn new() -> Self { Self { next_subscription_id: AtomicU32::new(1), - subscriptions: RefCell::new(heapless::Vec::new()), + subscriptions: RefCell::new(crate::utils::storage::Vec::new()), notification: Notification::new(), } } + /// Create an in-place initializer for the instance. + pub fn init() -> impl Init { + init!(Self { + next_subscription_id: AtomicU32::new(1), + subscriptions <- RefCell::init(crate::utils::storage::Vec::init()), + notification: Notification::new(), + }) + } + /// Notify the instance that some data in the data model has changed and that it should re-evaluate the subscriptions /// and report on those that concern the changed data. /// diff --git a/rs-matter/src/data_model/system_model/access_control.rs b/rs-matter/src/data_model/system_model/access_control.rs index aa7daeb0..7ee58133 100644 --- a/rs-matter/src/data_model/system_model/access_control.rs +++ b/rs-matter/src/data_model/system_model/access_control.rs @@ -134,7 +134,12 @@ impl AccessControlCluster { Attributes::Acl(_) => { writer.start_array(AttrDataWriter::TAG)?; acl_mgr.for_each_acl(|entry| { - if !attr.fab_filter || attr.fab_idx == entry.fab_idx.get() { + if !attr.fab_filter + || entry + .fab_idx + .map(|fi| fi.get() == attr.fab_idx) + .unwrap_or(false) + { entry.to_tlv(&mut writer, TagType::Anonymous)?; } @@ -184,7 +189,7 @@ impl AccessControlCluster { let mut acl_entry = AclEntry::from_tlv(data)?; info!("ACL {:?}", acl_entry); // Overwrite the fabric index with our accessing fabric index - acl_entry.fab_idx = fab_idx; + acl_entry.fab_idx = Some(fab_idx); if let ListOperation::EditItem(index) = op { acl_mgr.edit(*index as u8, fab_idx, acl_entry)?; @@ -230,7 +235,7 @@ mod tests { use crate::data_model::system_model::access_control::Dataver; use crate::interaction_model::messages::ib::ListOperation; use crate::tlv::{get_root_node_struct, ElementType, TLVElement, TLVWriter, TagType, ToTLV}; - use crate::utils::writebuf::WriteBuf; + use crate::utils::storage::WriteBuf; use super::AccessControlCluster; diff --git a/rs-matter/src/fabric.rs b/rs-matter/src/fabric.rs index cf71a63c..f6d8ba19 100644 --- a/rs-matter/src/fabric.rs +++ b/rs-matter/src/fabric.rs @@ -19,18 +19,19 @@ use core::fmt::Write; use core::num::NonZeroU8; use byteorder::{BigEndian, ByteOrder, LittleEndian}; -use heapless::{String, Vec}; + +use heapless::String; + use log::info; -use crate::{ - cert::{Cert, MAX_CERT_TLV_LEN}, - crypto::{self, hkdf_sha256, HmacSha256, KeyPair}, - error::{Error, ErrorCode}, - group_keys::KeySet, - mdns::{Mdns, ServiceMode}, - tlv::{self, FromTLV, OctetStr, TLVList, TLVWriter, TagType, ToTLV, UtfStr}, - utils::writebuf::WriteBuf, -}; +use crate::cert::{Cert, MAX_CERT_TLV_LEN}; +use crate::crypto::{self, hkdf_sha256, HmacSha256, KeyPair}; +use crate::error::{Error, ErrorCode}; +use crate::group_keys::KeySet; +use crate::mdns::{Mdns, ServiceMode}; +use crate::tlv::{self, FromTLV, OctetStr, TLVList, TLVWriter, TagType, ToTLV, UtfStr}; +use crate::utils::init::{init, Init}; +use crate::utils::storage::{Vec, WriteBuf}; const COMPRESSED_FABRIC_ID_LEN: usize = 8; @@ -64,9 +65,9 @@ pub struct Fabric { impl Fabric { pub fn new( key_pair: KeyPair, - root_ca: heapless::Vec, - icac: Option>, - noc: heapless::Vec, + root_ca: Vec, + icac: Option>, + noc: Vec, ipk: &[u8], vendor_id: u16, label: &str, @@ -206,6 +207,13 @@ impl FabricMgr { } } + pub fn init() -> impl Init { + init!(Self { + fabrics <- FabricEntries::init(), + changed: false, + }) + } + pub fn load(&mut self, data: &[u8], mdns: &dyn Mdns) -> Result<(), Error> { for fabric in self.fabrics.iter().flatten() { mdns.remove(&fabric.mdns_service_name)?; @@ -213,7 +221,7 @@ impl FabricMgr { let root = TLVList::new(data).iter().next().ok_or(ErrorCode::Invalid)?; - tlv::from_tlv(&mut self.fabrics, &root)?; + tlv::vec_from_tlv(&mut self.fabrics, &root)?; for fabric in self.fabrics.iter().flatten() { mdns.add(&fabric.mdns_service_name, ServiceMode::Commissioned)?; diff --git a/rs-matter/src/interaction_model/core.rs b/rs-matter/src/interaction_model/core.rs index 772b0411..7266e0a3 100644 --- a/rs-matter/src/interaction_model/core.rs +++ b/rs-matter/src/interaction_model/core.rs @@ -21,7 +21,7 @@ use crate::{ error::*, tlv::{FromTLV, TLVArray, TLVElement, TLVWriter, TagType, ToTLV}, transport::exchange::MessageMeta, - utils::{epoch::Epoch, writebuf::WriteBuf}, + utils::{epoch::Epoch, storage::WriteBuf}, }; use num::FromPrimitive; use num_derive::FromPrimitive; diff --git a/rs-matter/src/mdns.rs b/rs-matter/src/mdns.rs index a697275b..8852c1bb 100644 --- a/rs-matter/src/mdns.rs +++ b/rs-matter/src/mdns.rs @@ -17,7 +17,9 @@ use core::fmt::Write; -use crate::{data_model::cluster_basic_information::BasicInfoConfig, error::Error}; +use crate::data_model::cluster_basic_information::BasicInfoConfig; +use crate::error::Error; +use crate::utils::init::{init, Init}; #[cfg(all(feature = "std", target_os = "macos"))] #[path = "mdns/astro.rs"] @@ -98,48 +100,66 @@ pub enum MdnsService<'a> { Provided(&'a dyn Mdns), } -impl<'a> MdnsService<'a> { - pub(crate) const fn new_impl( - &self, +pub(crate) struct MdnsImpl<'a> { + service: MdnsService<'a>, + builtin: builtin::MdnsImpl<'a>, +} + +impl<'a> MdnsImpl<'a> { + pub(crate) const fn new( + service: MdnsService<'a>, dev_det: &'a BasicInfoConfig<'a>, matter_port: u16, - ) -> MdnsImpl<'a> { - match self { - Self::Disabled => MdnsImpl::Disabled, - Self::Builtin => MdnsImpl::Builtin(builtin::MdnsImpl::new(dev_det, matter_port)), - Self::Provided(mdns) => MdnsImpl::Provided(*mdns), + ) -> Self { + Self { + service, + builtin: builtin::MdnsImpl::new(dev_det, matter_port), } } -} -pub(crate) enum MdnsImpl<'a> { - Disabled, - Builtin(builtin::MdnsImpl<'a>), - Provided(&'a dyn Mdns), + pub(crate) fn init( + service: MdnsService<'a>, + dev_det: &'a BasicInfoConfig<'a>, + matter_port: u16, + ) -> impl Init { + init!(Self { + service, + builtin <- builtin::MdnsImpl::init(dev_det, matter_port), + }) + } + + #[allow(unused)] + pub(crate) fn builtin(&self) -> Option<&builtin::MdnsImpl> { + matches!(self.service, MdnsService::Builtin).then_some(&self.builtin) + } + + pub(crate) fn update(&mut self, service: MdnsService<'a>) { + self.service = service; + } } impl<'a> Mdns for MdnsImpl<'a> { fn reset(&self) { - match self { - Self::Disabled => {} - Self::Builtin(mdns) => mdns.reset(), - Self::Provided(mdns) => mdns.reset(), + match self.service { + MdnsService::Disabled => {} + MdnsService::Builtin => self.builtin.reset(), + MdnsService::Provided(mdns) => mdns.reset(), } } fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error> { - match self { - Self::Disabled => Ok(()), - Self::Builtin(mdns) => mdns.add(service, mode), - Self::Provided(mdns) => mdns.add(service, mode), + match self.service { + MdnsService::Disabled => Ok(()), + MdnsService::Builtin => self.builtin.add(service, mode), + MdnsService::Provided(mdns) => mdns.add(service, mode), } } fn remove(&self, service: &str) -> Result<(), Error> { - match self { - Self::Disabled => Ok(()), - Self::Builtin(mdns) => mdns.remove(service), - Self::Provided(mdns) => mdns.remove(service), + match self.service { + MdnsService::Disabled => Ok(()), + MdnsService::Builtin => self.builtin.remove(service), + MdnsService::Provided(mdns) => mdns.remove(service), } } } diff --git a/rs-matter/src/mdns/astro.rs b/rs-matter/src/mdns/astro.rs index 06c67eb8..413e0614 100644 --- a/rs-matter/src/mdns/astro.rs +++ b/rs-matter/src/mdns/astro.rs @@ -1,15 +1,13 @@ -use core::cell::RefCell; - use std::collections::BTreeMap; use astro_dnssd::{DNSServiceBuilder, RegisteredDnsService}; use log::info; -use crate::{ - data_model::cluster_basic_information::BasicInfoConfig, - error::{Error, ErrorCode}, -}; +use crate::data_model::cluster_basic_information::BasicInfoConfig; +use crate::error::{Error, ErrorCode}; +use crate::utils::cell::RefCell; +use crate::utils::init::{init, Init}; use super::ServiceMode; @@ -28,6 +26,14 @@ impl<'a> MdnsImpl<'a> { } } + pub fn init(dev_det: &'a BasicInfoConfig<'a>, matter_port: u16) -> impl Init { + init!(Self { + dev_det, + matter_port, + services <- RefCell::init(BTreeMap::new()), + }) + } + pub fn reset(&self) { self.services.borrow_mut().clear(); } diff --git a/rs-matter/src/mdns/builtin.rs b/rs-matter/src/mdns/builtin.rs index 235f1bc3..89207e6e 100644 --- a/rs-matter/src/mdns/builtin.rs +++ b/rs-matter/src/mdns/builtin.rs @@ -1,4 +1,3 @@ -use core::cell::RefCell; use core::net::IpAddr; use core::pin::pin; @@ -6,6 +5,7 @@ use embassy_futures::select::select; use embassy_sync::blocking_mutex::raw::{NoopRawMutex, RawMutex}; use embassy_sync::mutex::Mutex; use embassy_time::{Duration, Timer}; + use log::{info, warn}; use crate::data_model::cluster_basic_information::BasicInfoConfig; @@ -14,8 +14,12 @@ use crate::transport::network::{ Address, Ipv4Addr, Ipv6Addr, NetworkReceive, NetworkSend, SocketAddr, SocketAddrV4, SocketAddrV6, }; +use crate::utils::cell::RefCell; +use crate::utils::init::{init, Init}; use crate::utils::rand::Rand; -use crate::utils::{buf::BufferAccess, notification::Notification, select::Coalesce}; +use crate::utils::select::Coalesce; +use crate::utils::storage::pooled::BufferAccess; +use crate::utils::sync::Notification; use super::{Service, ServiceMode}; @@ -37,7 +41,7 @@ pub const MDNS_PORT: u16 = 5353; pub struct MdnsImpl<'a> { dev_det: &'a BasicInfoConfig<'a>, matter_port: u16, - services: RefCell, ServiceMode), 4>>, + services: RefCell, ServiceMode), 4>>, notification: Notification, } @@ -47,11 +51,20 @@ impl<'a> MdnsImpl<'a> { Self { dev_det, matter_port, - services: RefCell::new(heapless::Vec::new()), + services: RefCell::new(crate::utils::storage::Vec::new()), notification: Notification::new(), } } + pub fn init(dev_det: &'a BasicInfoConfig<'a>, matter_port: u16) -> impl Init { + init!(Self { + dev_det, + matter_port, + services <- RefCell::init(crate::utils::storage::Vec::init()), + notification: Notification::new(), + }) + } + pub fn reset(&self) { self.services.borrow_mut().clear(); } @@ -136,11 +149,11 @@ impl<'a> MdnsImpl<'a> { select(&mut notification, &mut timeout).await; - for addr in core::iter::once(SocketAddr::V4(SocketAddrV4::new( - MDNS_IPV4_BROADCAST_ADDR, - MDNS_PORT, - ))) - .chain( + for addr in Iterator::chain( + core::iter::once(SocketAddr::V4(SocketAddrV4::new( + MDNS_IPV4_BROADCAST_ADDR, + MDNS_PORT, + ))), interface .map(|interface| { SocketAddr::V6(SocketAddrV6::new( diff --git a/rs-matter/src/mdns/proto.rs b/rs-matter/src/mdns/proto.rs index bd2ed06b..2ba952dd 100644 --- a/rs-matter/src/mdns/proto.rs +++ b/rs-matter/src/mdns/proto.rs @@ -1064,7 +1064,7 @@ mod tests { assert_eq!(t, str); } - while let Some(t) = txt.next() { + for t in txt { if !t.is_empty() { panic!("Unexpected TXT string {:?} for {}", t, expected.owner); } diff --git a/rs-matter/src/mdns/zeroconf.rs b/rs-matter/src/mdns/zeroconf.rs index 5a04ea04..db5b45d1 100644 --- a/rs-matter/src/mdns/zeroconf.rs +++ b/rs-matter/src/mdns/zeroconf.rs @@ -1,5 +1,3 @@ -use core::cell::RefCell; - use std::collections::BTreeMap; use std::sync::mpsc::{sync_channel, SyncSender}; @@ -7,10 +5,10 @@ use log::error; use zeroconf::{prelude::TEventLoop, service::TMdnsService, txt_record::TTxtRecord, ServiceType}; -use crate::{ - data_model::cluster_basic_information::BasicInfoConfig, - error::{Error, ErrorCode}, -}; +use crate::data_model::cluster_basic_information::BasicInfoConfig; +use crate::error::{Error, ErrorCode}; +use crate::utils::cell::RefCell; +use crate::utils::init::{init, Init}; use super::ServiceMode; @@ -39,6 +37,14 @@ impl<'a> MdnsImpl<'a> { } } + pub fn init(dev_det: &'a BasicInfoConfig<'a>, matter_port: u16) -> impl Init { + init!(Self { + dev_det, + matter_port, + services <- RefCell::init(BTreeMap::new()), + }) + } + pub fn reset(&self) { self.services.borrow_mut().clear(); } diff --git a/rs-matter/src/pairing/qr.rs b/rs-matter/src/pairing/qr.rs index a5c08650..44c14af1 100644 --- a/rs-matter/src/pairing/qr.rs +++ b/rs-matter/src/pairing/qr.rs @@ -20,7 +20,7 @@ use qrcodegen_no_heap::{QrCode, QrCodeEcc, Version}; use crate::{ error::ErrorCode, tlv::{ElementType, TLVElement, TLVWriter, TagType, ToTLV}, - utils::writebuf::WriteBuf, + utils::storage::WriteBuf, }; use super::{ diff --git a/rs-matter/src/persist.rs b/rs-matter/src/persist.rs index a25b13a0..4936c544 100644 --- a/rs-matter/src/persist.rs +++ b/rs-matter/src/persist.rs @@ -19,58 +19,96 @@ pub use fileio::*; #[cfg(feature = "std")] pub mod fileio { + use core::mem::MaybeUninit; + use std::fs; use std::io::{Read, Write}; - use std::path::{Path, PathBuf}; + use std::path::Path; use log::info; use crate::error::{Error, ErrorCode}; + use crate::utils::init::{init, Init}; use crate::Matter; - pub struct Psm<'a> { - matter: &'a Matter<'a>, - dir: PathBuf, - buf: [u8; 4096], + pub struct Psm { + buf: MaybeUninit<[u8; N]>, + } + + impl Default for Psm { + fn default() -> Self { + Self::new() + } } - impl<'a> Psm<'a> { + impl Psm { #[inline(always)] - pub fn new(matter: &'a Matter<'a>, dir: PathBuf) -> Result { - fs::create_dir_all(&dir)?; + pub const fn new() -> Self { + Self { + buf: MaybeUninit::uninit(), + } + } - info!("Persisting from/to {}", dir.display()); + pub fn init() -> impl Init { + init!(Self { + buf <- crate::utils::init::zeroed(), + }) + } - let mut buf = [0; 4096]; + pub fn load(&mut self, dir: &Path, matter: &Matter) -> Result<(), Error> { + fs::create_dir_all(dir)?; - if let Some(data) = Self::load(&dir, "acls", &mut buf)? { + if let Some(data) = Self::load_key(dir, "acls", unsafe { self.buf.assume_init_mut() })? + { matter.load_acls(data)?; } - if let Some(data) = Self::load(&dir, "fabrics", &mut buf)? { + if let Some(data) = + Self::load_key(dir, "fabrics", unsafe { self.buf.assume_init_mut() })? + { matter.load_fabrics(data)?; } - Ok(Self { matter, dir, buf }) + Ok(()) } - pub async fn run(&mut self) -> Result<(), Error> { - loop { - self.matter.wait_changed().await; + pub fn store(&mut self, dir: &Path, matter: &Matter) -> Result<(), Error> { + if matter.is_changed() { + fs::create_dir_all(dir)?; - if self.matter.is_changed() { - if let Some(data) = self.matter.store_acls(&mut self.buf)? { - Self::store(&self.dir, "acls", data)?; - } + if let Some(data) = matter.store_acls(unsafe { self.buf.assume_init_mut() })? { + Self::store_key(dir, "acls", data)?; + } - if let Some(data) = self.matter.store_fabrics(&mut self.buf)? { - Self::store(&self.dir, "fabrics", data)?; - } + if let Some(data) = matter.store_fabrics(unsafe { self.buf.assume_init_mut() })? { + Self::store_key(dir, "fabrics", data)?; } } + + Ok(()) + } + + pub async fn run>( + &mut self, + dir: P, + matter: &Matter<'_>, + ) -> Result<(), Error> { + let dir = dir.as_ref(); + + self.load(dir, matter)?; + + loop { + matter.wait_changed().await; + + self.store(dir, matter)?; + } } - fn load<'b>(dir: &Path, key: &str, buf: &'b mut [u8]) -> Result, Error> { + fn load_key<'b>( + dir: &Path, + key: &str, + buf: &'b mut [u8], + ) -> Result, Error> { let path = dir.join(key); match fs::File::open(path) { @@ -101,7 +139,7 @@ pub mod fileio { } } - fn store(dir: &Path, key: &str, data: &[u8]) -> Result<(), Error> { + fn store_key(dir: &Path, key: &str, data: &[u8]) -> Result<(), Error> { let path = dir.join(key); let mut file = fs::File::create(path)?; diff --git a/rs-matter/src/respond.rs b/rs-matter/src/respond.rs index fd56fe7b..2abf6f46 100644 --- a/rs-matter/src/respond.rs +++ b/rs-matter/src/respond.rs @@ -31,8 +31,8 @@ use crate::interaction_model::core::PROTO_ID_INTERACTION_MODEL; use crate::secure_channel::busy::BusySecureChannel; use crate::secure_channel::core::SecureChannel; use crate::transport::exchange::Exchange; -use crate::utils::buf::BufferAccess; use crate::utils::select::Coalesce; +use crate::utils::storage::pooled::BufferAccess; use crate::Matter; /// Send a busy response if - after that many ms - the exchange diff --git a/rs-matter/src/secure_channel/case.rs b/rs-matter/src/secure_channel/case.rs index 1ecfd6cd..e1c4fe7c 100644 --- a/rs-matter/src/secure_channel/case.rs +++ b/rs-matter/src/secure_channel/case.rs @@ -31,7 +31,7 @@ use crate::{ exchange::Exchange, session::{NocCatIds, ReservedSession, SessionMode}, }, - utils::{rand::Rand, writebuf::WriteBuf}, + utils::{rand::Rand, storage::WriteBuf}, }; #[derive(Debug, Clone)] diff --git a/rs-matter/src/secure_channel/common.rs b/rs-matter/src/secure_channel/common.rs index 4d4d3921..8b559153 100644 --- a/rs-matter/src/secure_channel/common.rs +++ b/rs-matter/src/secure_channel/common.rs @@ -19,7 +19,7 @@ use num_derive::FromPrimitive; use crate::error::Error; use crate::transport::exchange::{Exchange, MessageMeta}; -use crate::utils::writebuf::WriteBuf; +use crate::utils::storage::WriteBuf; use super::status_report::{GeneralCode, StatusReport}; diff --git a/rs-matter/src/secure_channel/pake.rs b/rs-matter/src/secure_channel/pake.rs index d00434c9..9e7297fa 100644 --- a/rs-matter/src/secure_channel/pake.rs +++ b/rs-matter/src/secure_channel/pake.rs @@ -17,23 +17,25 @@ use core::{fmt::Write, time::Duration}; -use super::{ - common::SCStatusCodes, - spake2p::{Spake2P, VerifierData, MAX_SALT_SIZE_BYTES}, +use log::{error, info}; + +use crate::crypto; +use crate::error::{Error, ErrorCode}; +use crate::mdns::{Mdns, ServiceMode}; +use crate::secure_channel::common::{complete_with_status, OpCode}; +use crate::tlv::{self, get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType, ToTLV}; +use crate::transport::{ + exchange::{Exchange, ExchangeId}, + session::{ReservedSession, SessionMode}, }; -use crate::{ - crypto, - error::{Error, ErrorCode}, - mdns::{Mdns, ServiceMode}, - secure_channel::common::{complete_with_status, OpCode}, - tlv::{self, get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType, ToTLV}, - transport::{ - exchange::{Exchange, ExchangeId}, - session::{ReservedSession, SessionMode}, - }, - utils::{epoch::Epoch, rand::Rand}, +use crate::utils::{ + epoch::Epoch, + init::{init, Init}, + rand::Rand, }; -use log::{error, info}; + +use super::common::SCStatusCodes; +use super::spake2p::{Spake2P, VerifierData, MAX_SALT_SIZE_BYTES}; struct PaseSession { mdns_service_name: heapless::String<16>, @@ -58,6 +60,17 @@ impl PaseMgr { } } + pub fn init(epoch: Epoch, rand: Rand) -> impl Init { + // TODO: Optimize in future because `PaseSession` is + // relatively large and we are creating it using stack moves. + init!(Self { + session: None, + timeout: None, + epoch, + rand, + }) + } + pub fn is_pase_session_enabled(&self) -> bool { self.session.is_some() } diff --git a/rs-matter/src/secure_channel/status_report.rs b/rs-matter/src/secure_channel/status_report.rs index d365dd1a..bb433a17 100644 --- a/rs-matter/src/secure_channel/status_report.rs +++ b/rs-matter/src/secure_channel/status_report.rs @@ -19,7 +19,7 @@ use num_derive::FromPrimitive; use crate::{ error::{Error, ErrorCode}, - utils::{parsebuf::ParseBuf, writebuf::WriteBuf}, + utils::storage::{ParseBuf, WriteBuf}, }; #[allow(dead_code)] diff --git a/rs-matter/src/tlv/traits.rs b/rs-matter/src/tlv/traits.rs index acb47c01..2839fb06 100644 --- a/rs-matter/src/tlv/traits.rs +++ b/rs-matter/src/tlv/traits.rs @@ -79,6 +79,24 @@ pub fn from_tlv<'a, T: FromTLV<'a>, const N: usize>( Ok(()) } +pub fn vec_from_tlv<'a, T: FromTLV<'a>, const N: usize>( + vec: &mut crate::utils::storage::Vec, + t: &TLVElement<'a>, +) -> Result<(), Error> { + vec.clear(); + + t.confirm_array()?; + + if let Some(tlv_iter) = t.enter() { + for element in tlv_iter { + vec.push(T::from_tlv(&element)?) + .map_err(|_| ErrorCode::NoSpace)?; + } + } + + Ok(()) +} + macro_rules! fromtlv_for { ($($t:ident)*) => { $( @@ -237,6 +255,19 @@ impl ToTLV for heapless::Vec { } } +/// Implements the Owned version of Octet String +impl FromTLV<'_> for crate::utils::storage::Vec { + fn from_tlv(t: &TLVElement) -> Result, Error> { + crate::utils::storage::Vec::from_slice(t.slice()?).map_err(|_| ErrorCode::NoSpace.into()) + } +} + +impl ToTLV for crate::utils::storage::Vec { + fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { + tw.str16(tag, self.as_slice()) + } +} + /// Implements the Owned version of UTF String impl FromTLV<'_> for heapless::String { fn from_tlv(t: &TLVElement) -> Result, Error> { @@ -538,7 +569,8 @@ macro_rules! bitflags_tlv { #[cfg(test)] mod tests { use super::{FromTLV, OctetStr, TLVWriter, TagType, ToTLV}; - use crate::{tlv::TLVList, utils::writebuf::WriteBuf}; + use crate::tlv::TLVList; + use crate::utils::storage::WriteBuf; use rs_matter_macros::{FromTLV, ToTLV}; #[derive(ToTLV)] diff --git a/rs-matter/src/tlv/writer.rs b/rs-matter/src/tlv/writer.rs index 45c60c97..e6af4efe 100644 --- a/rs-matter/src/tlv/writer.rs +++ b/rs-matter/src/tlv/writer.rs @@ -15,10 +15,13 @@ * limitations under the License. */ -use super::{TagType, TAG_SHIFT_BITS, TAG_SIZE_MAP}; -use crate::{error::*, utils::writebuf::WriteBuf}; use log::error; +use crate::error::*; +use crate::utils::storage::WriteBuf; + +use super::{TagType, TAG_SHIFT_BITS, TAG_SIZE_MAP}; + #[allow(dead_code)] enum WriteElementType { S8 = 0, @@ -273,7 +276,7 @@ impl<'a, 'b> TLVWriter<'a, 'b> { #[cfg(test)] mod tests { use super::{TLVWriter, TagType}; - use crate::utils::writebuf::WriteBuf; + use crate::utils::storage::WriteBuf; #[test] fn test_write_success() { diff --git a/rs-matter/src/transport/core.rs b/rs-matter/src/transport/core.rs index 3559e35e..b1e386bd 100644 --- a/rs-matter/src/transport/core.rs +++ b/rs-matter/src/transport/core.rs @@ -15,7 +15,6 @@ * limitations under the License. */ -use core::cell::RefCell; use core::fmt::{self, Display}; use core::ops::{Deref, DerefMut}; use core::pin::pin; @@ -26,21 +25,19 @@ use embassy_time::Timer; use log::{debug, error, info, trace, warn}; +use crate::data_model::cluster_basic_information::BasicInfoConfig; use crate::error::{Error, ErrorCode}; -use crate::mdns::MdnsImpl; +use crate::mdns::{MdnsImpl, MdnsService}; use crate::secure_channel::common::{sc_write, OpCode, SCStatusCodes, PROTO_ID_SECURE_CHANNEL}; use crate::secure_channel::status_report::StatusReport; use crate::tlv::TLVList; -use crate::utils::buf::BufferAccess; -use crate::utils::{ - epoch::Epoch, - ifmutex::{IfMutex, IfMutexGuard}, - notification::Notification, - parsebuf::ParseBuf, - rand::Rand, - select::Coalesce, - writebuf::WriteBuf, -}; +use crate::utils::cell::RefCell; +use crate::utils::epoch::Epoch; +use crate::utils::init::{init, Init}; +use crate::utils::rand::Rand; +use crate::utils::select::Coalesce; +use crate::utils::storage::{pooled::BufferAccess, ParseBuf, WriteBuf}; +use crate::utils::sync::{IfMutex, IfMutexGuard, Notification}; use crate::{Matter, MATTER_PORT}; use super::exchange::{Exchange, ExchangeId, ExchangeState, MessageMeta, ResponderState, Role}; @@ -86,33 +83,58 @@ pub struct TransportMgr<'m> { impl<'m> TransportMgr<'m> { #[inline(always)] - pub(crate) const fn new(mdns: MdnsImpl<'m>, epoch: Epoch, rand: Rand) -> Self { + pub(crate) const fn new( + service: MdnsService<'m>, + dev_det: &'m BasicInfoConfig<'m>, + matter_port: u16, + epoch: Epoch, + rand: Rand, + ) -> Self { Self { rx: IfMutex::new(Packet::new()), tx: IfMutex::new(Packet::new()), dropped: Notification::new(), session_removed: Notification::new(), session_mgr: RefCell::new(SessionMgr::new(epoch, rand)), - mdns, + mdns: MdnsImpl::new(service, dev_det, matter_port), rand, } } - pub(crate) fn replace_mdns(&mut self, mdns: MdnsImpl<'m>) { - self.mdns = mdns; + pub(crate) fn init( + service: MdnsService<'m>, + dev_det: &'m BasicInfoConfig<'m>, + matter_port: u16, + epoch: Epoch, + rand: Rand, + ) -> impl Init { + init!(Self { + rx <- IfMutex::init(Packet::init()), + tx <- IfMutex::init(Packet::init()), + dropped: Notification::new(), + session_removed: Notification::new(), + session_mgr <- RefCell::init(SessionMgr::init(epoch, rand)), + mdns <- MdnsImpl::init(service, dev_det, matter_port), + rand, + }) } + pub(crate) fn replace_mdns(&mut self, mdns: MdnsService<'m>) { + self.mdns.update(mdns); + } + + // TODO #[cfg(all(feature = "large-buffers", feature = "alloc"))] pub fn initialize_buffers(&self) -> Result<(), Error> { let mut rx = self.rx.try_lock().map_err(|_| ErrorCode::InvalidState)?; let mut tx = self.tx.try_lock().map_err(|_| ErrorCode::InvalidState)?; if rx.buf.0.is_none() { - rx.buf.0 = Some(alloc::boxed::Box::new(heapless::Vec::new())); + rx.buf.0 = Some(alloc::boxed::Box::new(crate::utils::storage::Vec::new())); } if tx.buf.0.is_none() { - tx.buf.0 = Some(alloc::boxed::Box::new(heapless::Vec::new())); + tx.buf.0 = Some(alloc::boxed::Box::new(crate::utils::storage::Vec::new())); } Ok(()) @@ -274,7 +296,7 @@ impl<'m> TransportMgr<'m> { { info!("Running Matter built-in mDNS service"); - if let MdnsImpl::Builtin(mdns) = &self.mdns { + if let Some(mdns) = self.mdns.builtin() { mdns.run( send, recv, @@ -1046,6 +1068,15 @@ impl Packet { } } + pub(crate) fn init() -> impl Init { + init!(Self { + peer: Address::new(), + header: PacketHdr::new(), + buf <- PacketBuffer::init(), + payload_start: 0, + }) + } + pub fn display<'a>(peer: &'a Address, header: &'a PacketHdr) -> impl Display + 'a { struct PacketInfo<'a>(&'a Address, &'a PacketHdr); @@ -1116,54 +1147,66 @@ impl Display for Packet { // // This type is only known and used by `TransportMgr` and the `exchange` module #[cfg(all(feature = "large-buffers", feature = "alloc"))] -pub(crate) struct PacketBuffer(Option>>); +pub(crate) struct PacketBuffer( + Option>>, +); // The buffer used inside the pair of RX and TX `Packet` instances // When the either of the `alloc` and `large-buffers` features is not enabled, the buffer payload is allocated inline // // This type is only known and used by `TransportMgr` and the `exchange` module #[cfg(not(all(feature = "large-buffers", feature = "alloc")))] -pub(crate) struct PacketBuffer(heapless::Vec); +pub(crate) struct PacketBuffer { + buffer: crate::utils::storage::Vec, +} impl PacketBuffer { #[cfg(all(feature = "large-buffers", feature = "alloc"))] pub const fn new() -> Self { - Self(None) + Self { buffer: None } } #[cfg(not(all(feature = "large-buffers", feature = "alloc")))] pub const fn new() -> Self { - Self(heapless::Vec::new()) + Self { + buffer: crate::utils::storage::Vec::new(), + } + } + + pub fn init() -> impl Init { + init!(Self { + buffer <- crate::utils::storage::Vec::init(), + }) } #[cfg(all(feature = "large-buffers", feature = "alloc"))] - pub fn buf_mut(&mut self) -> &mut heapless::Vec { + pub fn buf_mut(&mut self) -> &mut crate::utils::storage::Vec { &mut *self - .0 + .buffer .as_mut() .expect("Buffer is not allocated. Did you forget to call `initialize_buffers`?") } #[cfg(not(all(feature = "large-buffers", feature = "alloc")))] - pub fn buf_mut(&mut self) -> &mut heapless::Vec { - &mut self.0 + pub fn buf_mut(&mut self) -> &mut crate::utils::storage::Vec { + &mut self.buffer } #[cfg(all(feature = "large-buffers", feature = "alloc"))] - pub fn buf_ref(&self) -> &heapless::Vec { - self.0 + pub fn buf_ref(&self) -> crate::utils::storage::Vec { + self.buffer .as_ref() .expect("Buffer is not allocated. Did you forget to call `initialize_buffers`?") } #[cfg(not(all(feature = "large-buffers", feature = "alloc")))] - pub fn buf_ref(&self) -> &heapless::Vec { - &self.0 + pub fn buf_ref(&self) -> &crate::utils::storage::Vec { + &self.buffer } } impl Deref for PacketBuffer { - type Target = heapless::Vec; + type Target = crate::utils::storage::Vec; fn deref(&self) -> &Self::Target { self.buf_ref() diff --git a/rs-matter/src/transport/exchange.rs b/rs-matter/src/transport/exchange.rs index 4d800251..5ec3d10a 100644 --- a/rs-matter/src/transport/exchange.rs +++ b/rs-matter/src/transport/exchange.rs @@ -27,7 +27,8 @@ use crate::acl::Accessor; use crate::error::{Error, ErrorCode}; use crate::interaction_model::{self, core::PROTO_ID_INTERACTION_MODEL}; use crate::secure_channel::{self, common::PROTO_ID_SECURE_CHANNEL}; -use crate::utils::{epoch::Epoch, writebuf::WriteBuf}; +use crate::utils::epoch::Epoch; +use crate::utils::storage::WriteBuf; use crate::Matter; use super::core::{Packet, PacketAccess, MAX_RX_BUF_SIZE, MAX_TX_BUF_SIZE}; diff --git a/rs-matter/src/transport/network/btp.rs b/rs-matter/src/transport/network/btp.rs index 06724ba6..23cd5dc4 100644 --- a/rs-matter/src/transport/network/btp.rs +++ b/rs-matter/src/transport/network/btp.rs @@ -27,13 +27,15 @@ use embassy_time::{Duration, Instant, Timer}; use log::trace; use context::LockError; + use session::{BTP_ACK_TIMEOUT_SECS, BTP_CONN_IDLE_TIMEOUT_SECS}; use crate::data_model::cluster_basic_information::BasicInfoConfig; use crate::error::{Error, ErrorCode}; use crate::transport::network::{Address, BtAddr, NetworkReceive, NetworkSend}; -use crate::utils::ifmutex::IfMutex; +use crate::utils::init::{init, Init}; use crate::utils::select::Coalesce; +use crate::utils::sync::IfMutex; use crate::CommissioningData; pub use context::{BtpContext, MAX_BTP_SESSIONS}; @@ -65,7 +67,7 @@ pub(crate) const MAX_MTU: u16 = (MAX_BTP_SEGMENT_SIZE + GATT_HEADER_SIZE) as u16 pub struct Btp { gatt: T, context: C, - send_buf: IfMutex>, + send_buf: IfMutex>, ack_timeout_secs: u16, conn_idle_timeout_secs: u16, _mutex: PhantomData, @@ -81,6 +83,10 @@ where pub fn new_builtin(context: C) -> Self { Self::new(BuiltinGattPeripheral::new(None), context) } + + pub fn init_builtin>(context: C) -> impl Init { + Self::init(BuiltinGattPeripheral::new(None), context) + } } impl Btp @@ -111,13 +117,27 @@ where Self { gatt, context, - send_buf: IfMutex::new(heapless::Vec::new()), + send_buf: IfMutex::new(crate::utils::storage::Vec::new()), ack_timeout_secs, conn_idle_timeout_secs, _mutex: PhantomData, } } + /// Create an in-place initializer for a BTP object with the provided + /// `GattPeripheral` trait in-place initializer and and with the provided BTP + /// `context` in-place initializer. + pub fn init, IC: Init>(gatt: IT, context: IC) -> impl Init { + init!(Self { + gatt <- gatt, + context <- context, + send_buf <- IfMutex::init(crate::utils::storage::Vec::init()), + ack_timeout_secs: BTP_ACK_TIMEOUT_SECS, + conn_idle_timeout_secs: BTP_CONN_IDLE_TIMEOUT_SECS, + _mutex: PhantomData, + }) + } + /// Run the BTP protocol /// /// While all sending and receiving of Matter packets (a.k.a. BTP SDUs) is done via the `recv` and `send` methods @@ -291,7 +311,7 @@ where /// in case it is used by another operation. async fn send_buf( &self, - ) -> impl DerefMut> + '_ { + ) -> impl DerefMut> + '_ { let mut buf = self.send_buf.lock().await; // Unwrap is safe because the max size of the buffer is MAX_PDU_SIZE diff --git a/rs-matter/src/transport/network/btp/context.rs b/rs-matter/src/transport/network/btp/context.rs index 117439e9..8634dff9 100644 --- a/rs-matter/src/transport/network/btp/context.rs +++ b/rs-matter/src/transport/network/btp/context.rs @@ -15,14 +15,16 @@ * limitations under the License. */ -use core::cell::RefCell; +use embassy_sync::blocking_mutex::raw::RawMutex; -use embassy_sync::blocking_mutex::{raw::RawMutex, Mutex}; use log::{error, info, trace, warn}; use crate::error::{Error, ErrorCode}; use crate::transport::network::BtAddr; -use crate::utils::notification::Notification; +use crate::utils::cell::RefCell; +use crate::utils::init::{init, Init, IntoFallibleInit}; +use crate::utils::sync::blocking::Mutex; +use crate::utils::sync::Notification; use super::{session::Session, GattPeripheralEvent}; @@ -182,7 +184,7 @@ pub struct BtpContext where M: RawMutex, { - pub(crate) sessions: Mutex>>, + pub(crate) sessions: Mutex>>, pub(crate) handshake_notif: Notification, pub(crate) available_notif: Notification, pub(crate) recv_notif: Notification, @@ -207,7 +209,7 @@ where #[inline(always)] pub const fn new() -> Self { Self { - sessions: Mutex::new(RefCell::new(heapless::Vec::new())), + sessions: Mutex::new(RefCell::new(crate::utils::storage::Vec::new())), handshake_notif: Notification::new(), available_notif: Notification::new(), recv_notif: Notification::new(), @@ -215,6 +217,18 @@ where send_notif: Notification::new(), } } + + /// Create a BTP context in-place initializer. + pub fn init() -> impl Init { + init!(Self { + sessions <- Mutex::init(RefCell::init(crate::utils::storage::Vec::init())), + handshake_notif: Notification::new(), + available_notif: Notification::new(), + recv_notif: Notification::new(), + ack_notif: Notification::new(), + send_notif: Notification::new(), + }) + } } impl BtpContext @@ -251,9 +265,11 @@ where warn!("Too many BTP sessions, dropping a handshake request from address {address}"); } else { // Unwrap is safe because we checked the length above - sessions - .push(Session::process_rx_handshake(address, data, gatt_mtu)?) - .unwrap(); + sessions.push_init( + Session::process_rx_handshake(address, data, gatt_mtu)?.into_fallible::(), + || ErrorCode::NoSpace.into(), + ) + .unwrap(); } Ok(()) diff --git a/rs-matter/src/transport/network/btp/gatt/bluer.rs b/rs-matter/src/transport/network/btp/gatt/bluer.rs index ffac0740..525a1e7c 100644 --- a/rs-matter/src/transport/network/btp/gatt/bluer.rs +++ b/rs-matter/src/transport/network/btp/gatt/bluer.rs @@ -16,6 +16,7 @@ */ use core::iter::once; +use core::ptr::addr_of_mut; use alloc::sync::Arc; @@ -36,12 +37,14 @@ use log::{info, trace, warn}; use tokio::io::AsyncWriteExt; use tokio_stream::StreamExt; +use crate::error::{Error, ErrorCode}; use crate::transport::network::btp::MIN_MTU; -use crate::{ - error::{Error, ErrorCode}, - transport::network::{btp::context::MAX_BTP_SESSIONS, BtAddr}, - utils::{ifmutex::IfMutex, select::Coalesce, signal::Signal, std_mutex::StdRawMutex}, -}; +use crate::transport::network::{btp::context::MAX_BTP_SESSIONS, BtAddr}; +use crate::utils::init::{init_from_closure, Init}; +use crate::utils::select::Coalesce; +use crate::utils::sync::blocking::raw::StdRawMutex; +use crate::utils::sync::IfMutex; +use crate::utils::sync::Signal; use super::{AdvData, GattPeripheral, GattPeripheralEvent}; use super::{C1_CHARACTERISTIC_UUID, C2_CHARACTERISTIC_UUID, MATTER_BLE_SERVICE_UUID}; @@ -61,6 +64,16 @@ struct GattState { notifiers_listen_allowed: Signal, } +impl GattState { + pub fn new(adapter_name: Option<&str>) -> Self { + Self { + adapter_name: adapter_name.map(|name| name.into()), + notifiers: IfMutex::new(heapless::Vec::new()), + notifiers_listen_allowed: Signal::new(true), + } + } +} + /// Implements the `GattPeripheral` trait using the BlueZ GATT stack. #[derive(Clone)] pub struct BluerGattPeripheral(Arc); @@ -74,11 +87,23 @@ impl Default for BluerGattPeripheral { impl BluerGattPeripheral { /// Create a new instance. pub fn new(adapter_name: Option<&str>) -> Self { - Self(Arc::new(GattState { - adapter_name: adapter_name.map(|name| name.into()), - notifiers: IfMutex::new(heapless::Vec::new()), - notifiers_listen_allowed: Signal::new(true), - })) + Self(Arc::new(GattState::new(adapter_name))) + } + + /// Create an in-place initializer for the peripheral. + pub fn init(adapter_name: Option<&str>) -> impl Init { + let adapter_name = adapter_name.map(|name| name.to_owned()); + + // We can't (yet) use `pinned-init`'s `InPlaceInit` as it relies on unstable Rust features. + // Not so important specifically for this BlueZ implementation because it is STD-only + // and Linux-only, and there's plenty of stack space on Linux. + unsafe { + init_from_closure(move |slot: *mut BluerGattPeripheral| { + addr_of_mut!((*slot).0).write(Arc::new(GattState::new(adapter_name.as_deref()))); + + Ok(()) + }) + } } /// Runs the GATT peripheral service. diff --git a/rs-matter/src/transport/network/btp/session.rs b/rs-matter/src/transport/network/btp/session.rs index e09dd245..21a3a472 100644 --- a/rs-matter/src/transport/network/btp/session.rs +++ b/rs-matter/src/transport/network/btp/session.rs @@ -26,7 +26,8 @@ use crate::error::{Error, ErrorCode}; use crate::transport::network::btp::session::packet::{HandshakeReq, HandshakeResp}; use crate::transport::network::btp::{GATT_HEADER_SIZE, MAX_MTU, MIN_MTU}; use crate::transport::network::{BtAddr, MAX_RX_PACKET_SIZE}; -use crate::utils::{ringbuf::RingBuf, writebuf::WriteBuf}; +use crate::utils::init::{init, Init}; +use crate::utils::storage::{RingBuf, WriteBuf}; use self::packet::BtpHdr; @@ -172,18 +173,17 @@ struct RecvWindow { } impl RecvWindow { - /// Initialize a new receiving window with the provided window size. - #[inline(always)] - pub const fn new(window_size: u8) -> Self { - Self { - buf: RingBuf::new(), + /// Create an in-place initializer for a receiving window with the provided window size. + pub fn init(window_size: u8) -> impl Init { + init!(Self { + buf <- RingBuf::init(), buf_messages_ct: 0, level: window_size, ack_level: 0, ack_seq: 255, received_at: Instant::MAX, rem_msg_len: 0, - } + }) } /// Process an incoming BTP segment, updating the state of the window accordingly. @@ -375,21 +375,21 @@ pub struct Session { } impl Session { - /// Initialize a new BTP session with the provided address, version, MTU and window size. + /// Return an in-place initializer for a new BTP session with the provided address, version, + /// MTU and window size. /// /// Initializing a session is done based on the data that had arrived in the Handshake Request message, /// written by a remote peer on the `C1` characteristic. - #[inline(always)] - const fn new(address: BtAddr, version: u8, mtu: u16, window_size: u8) -> Self { - Self { + fn init(address: BtAddr, version: u8, mtu: u16, window_size: u8) -> impl Init { + init!(Self { address, state: SessionState::New, version, mtu, window_size, - recv_window: RecvWindow::new(window_size), + recv_window <- RecvWindow::init(window_size), send_window: SendWindow::new(window_size), - } + }) } /// Return the address of the remote peer. @@ -495,7 +495,7 @@ impl Session { address: BtAddr, data: &[u8], gatt_mtu: Option, - ) -> Result { + ) -> Result, Error> { let mut iter = data.iter(); let hdr = BtpHdr::from((&mut iter).copied())?; @@ -537,7 +537,7 @@ impl Session { info!("\n>>>>> (BTP IO) {address} [{hdr}]\nHANDSHAKE REQ {req:?}\nSelected version: {version}, MTU: {mtu}, window size: {window_size}"); - Ok(Self::new(address, version, mtu, window_size)) + Ok(Self::init(address, version, mtu, window_size)) } /// Process an incoming BTP segment of a regular data or ACK type, updating the state of the session accordingly. diff --git a/rs-matter/src/transport/network/btp/session/packet.rs b/rs-matter/src/transport/network/btp/session/packet.rs index 1cfef15a..eb681631 100644 --- a/rs-matter/src/transport/network/btp/session/packet.rs +++ b/rs-matter/src/transport/network/btp/session/packet.rs @@ -22,7 +22,7 @@ use bitflags::bitflags; use log::trace; use crate::error::{Error, ErrorCode}; -use crate::utils::writebuf::WriteBuf; +use crate::utils::storage::WriteBuf; bitflags! { /// Models the flags in the BTP header. diff --git a/rs-matter/src/transport/network/btp/test.rs b/rs-matter/src/transport/network/btp/test.rs index 5c56284b..d9ef4a25 100644 --- a/rs-matter/src/transport/network/btp/test.rs +++ b/rs-matter/src/transport/network/btp/test.rs @@ -26,7 +26,8 @@ use alloc::{vec, vec::Vec}; use embassy_futures::block_on; use crate::secure_channel::spake2p::VerifierData; -use crate::utils::{rand::sys_rand, std_mutex::StdRawMutex}; +use crate::utils::rand::sys_rand; +use crate::utils::sync::blocking::raw::StdRawMutex; use super::*; diff --git a/rs-matter/src/transport/packet.rs b/rs-matter/src/transport/packet.rs index 8c4dd882..a1790373 100644 --- a/rs-matter/src/transport/packet.rs +++ b/rs-matter/src/transport/packet.rs @@ -19,11 +19,9 @@ use core::fmt; use log::trace; -use crate::{ - crypto::AEAD_MIC_LEN_BYTES, - error::Error, - utils::{parsebuf::ParseBuf, writebuf::WriteBuf}, -}; +use crate::crypto::AEAD_MIC_LEN_BYTES; +use crate::error::Error; +use crate::utils::storage::{ParseBuf, WriteBuf}; use super::{ plain_hdr::{self, PlainHdr}, diff --git a/rs-matter/src/transport/plain_hdr.rs b/rs-matter/src/transport/plain_hdr.rs index e2b54812..609e88c3 100644 --- a/rs-matter/src/transport/plain_hdr.rs +++ b/rs-matter/src/transport/plain_hdr.rs @@ -18,8 +18,7 @@ use core::fmt; use crate::error::*; -use crate::utils::parsebuf::ParseBuf; -use crate::utils::writebuf::WriteBuf; +use crate::utils::storage::{ParseBuf, WriteBuf}; use bitflags::bitflags; use log::trace; diff --git a/rs-matter/src/transport/proto_hdr.rs b/rs-matter/src/transport/proto_hdr.rs index 65b3bb02..b5641071 100644 --- a/rs-matter/src/transport/proto_hdr.rs +++ b/rs-matter/src/transport/proto_hdr.rs @@ -19,8 +19,7 @@ use bitflags::bitflags; use core::fmt; use crate::transport::plain_hdr; -use crate::utils::parsebuf::ParseBuf; -use crate::utils::writebuf::WriteBuf; +use crate::utils::storage::{ParseBuf, WriteBuf}; use crate::{crypto, error::*}; use log::{trace, warn}; diff --git a/rs-matter/src/transport/session.rs b/rs-matter/src/transport/session.rs index 70110059..e0891268 100644 --- a/rs-matter/src/transport/session.rs +++ b/rs-matter/src/transport/session.rs @@ -15,7 +15,6 @@ * limitations under the License. */ -use core::cell::RefCell; use core::fmt; use core::num::NonZeroU8; use core::time::Duration; @@ -26,10 +25,11 @@ use crate::data_model::sdm::noc::NocData; use crate::error::*; use crate::transport::exchange::ExchangeId; use crate::transport::mrp::ReliableMessage; +use crate::utils::cell::RefCell; use crate::utils::epoch::Epoch; -use crate::utils::parsebuf::ParseBuf; +use crate::utils::init::{init, Init}; use crate::utils::rand::Rand; -use crate::utils::writebuf::WriteBuf; +use crate::utils::storage::{ParseBuf, WriteBuf}; use crate::Matter; use super::dedup::RxCtrState; @@ -538,16 +538,17 @@ pub struct SessionMgr { next_sess_unique_id: u32, next_sess_id: u16, next_exch_id: u16, - sessions: heapless::Vec, + sessions: crate::utils::storage::Vec, pub(crate) epoch: Epoch, pub(crate) rand: Rand, } impl SessionMgr { + /// Create a new session manager. #[inline(always)] pub const fn new(epoch: Epoch, rand: Rand) -> Self { Self { - sessions: heapless::Vec::new(), + sessions: crate::utils::storage::Vec::new(), next_sess_unique_id: 0, next_sess_id: 1, next_exch_id: 1, @@ -556,6 +557,18 @@ impl SessionMgr { } } + /// Create an in-place initializer for a new session manager. + pub fn init(epoch: Epoch, rand: Rand) -> impl Init { + init!(Self { + sessions <- crate::utils::storage::Vec::init(), + next_sess_unique_id: 0, + next_sess_id: 1, + next_exch_id: 1, + epoch, + rand, + }) + } + pub fn reset(&mut self) { self.sessions.clear(); self.next_sess_id = 1; diff --git a/rs-matter/src/utils/cell.rs b/rs-matter/src/utils/cell.rs new file mode 100644 index 00000000..015d8563 --- /dev/null +++ b/rs-matter/src/utils/cell.rs @@ -0,0 +1,1094 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! A modification of the Rust `RefCell` type which provides in-place initialization +//! via `RefCell::init`. +//! +//! NOTE: TEMPORARY and subject to removal once all Matter state is hidden behind +//! a `blmutex::Mutex` in future. + +#![allow(unexpected_cfgs)] +#![allow(clippy::should_implement_trait)] + +use core::cell::{Cell, UnsafeCell}; +use core::cmp::Ordering; +use core::fmt::{self, Debug, Display}; +use core::marker::PhantomData; +use core::mem; +use core::ops::{Deref, DerefMut}; +use core::ptr::NonNull; + +use super::init::{init, Init, UnsafeCellInit}; + +/// A mutable memory location with dynamically checked borrow rules +/// +/// See the [module-level documentation](self) for more. +pub struct RefCell { + borrow: Cell, + // Stores the location of the earliest currently active borrow. + // This gets updated whenever we go from having zero borrows + // to having a single borrow. When a borrow occurs, this gets included + // in the generated `BorrowError`/`BorrowMutError` + #[cfg(feature = "debug_refcell")] + borrowed_at: Cell>>, + _not_sync: PhantomData<*const ()>, + value: UnsafeCell, +} + +/// An error returned by [`RefCell::try_borrow`]. +#[non_exhaustive] +pub struct BorrowError { + #[cfg(feature = "debug_refcell")] + location: &'static crate::panic::Location<'static>, +} + +impl Debug for BorrowError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = f.debug_struct("BorrowError"); + + #[cfg(feature = "debug_refcell")] + builder.field("location", self.location); + + builder.finish() + } +} + +impl Display for BorrowError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + Display::fmt("already mutably borrowed", f) + } +} + +/// An error returned by [`RefCell::try_borrow_mut`]. +#[non_exhaustive] +pub struct BorrowMutError { + #[cfg(feature = "debug_refcell")] + location: &'static crate::panic::Location<'static>, +} + +impl Debug for BorrowMutError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = f.debug_struct("BorrowMutError"); + + #[cfg(feature = "debug_refcell")] + builder.field("location", self.location); + + builder.finish() + } +} + +impl Display for BorrowMutError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + Display::fmt("already borrowed", f) + } +} + +// This ensures the panicking code is outlined from `borrow_mut` for `RefCell`. +#[cfg_attr(not(feature = "panic_immediate_abort"), inline(never))] +#[track_caller] +#[cold] +fn panic_already_borrowed(err: BorrowMutError) -> ! { + panic!("already borrowed: {:?}", err) +} + +// This ensures the panicking code is outlined from `borrow` for `RefCell`. +#[cfg_attr(not(feature = "panic_immediate_abort"), inline(never))] +#[track_caller] +#[cold] +fn panic_already_mutably_borrowed(err: BorrowError) -> ! { + panic!("already mutably borrowed: {:?}", err) +} + +// Positive values represent the number of `Ref` active. Negative values +// represent the number of `RefMut` active. Multiple `RefMut`s can only be +// active at a time if they refer to distinct, nonoverlapping components of a +// `RefCell` (e.g., different ranges of a slice). +// +// `Ref` and `RefMut` are both two words in size, and so there will likely never +// be enough `Ref`s or `RefMut`s in existence to overflow half of the `usize` +// range. Thus, a `BorrowFlag` will probably never overflow or underflow. +// However, this is not a guarantee, as a pathological program could repeatedly +// create and then mem::forget `Ref`s or `RefMut`s. Thus, all code must +// explicitly check for overflow and underflow in order to avoid unsafety, or at +// least behave correctly in the event that overflow or underflow happens (e.g., +// see BorrowRef::new). +type BorrowFlag = isize; +const UNUSED: BorrowFlag = 0; + +#[inline(always)] +fn is_writing(x: BorrowFlag) -> bool { + x < UNUSED +} + +#[inline(always)] +fn is_reading(x: BorrowFlag) -> bool { + x > UNUSED +} + +impl RefCell { + /// Creates a new `RefCell` containing `value`. + /// + /// # Examples + /// + /// ``` + /// use std::cell::RefCell; + /// + /// let c = RefCell::new(5); + /// ``` + #[inline] + pub const fn new(value: T) -> RefCell { + RefCell { + value: UnsafeCell::new(value), + borrow: Cell::new(UNUSED), + #[cfg(feature = "debug_refcell")] + borrowed_at: Cell::new(None), + _not_sync: PhantomData, + } + } + + /// Creates a new `RefCell` in-place initializer + /// by using the given value initializer. + pub fn init>(value: I) -> impl Init { + init!(Self { + value <- UnsafeCell::init(value), + borrow: Cell::new(UNUSED), + // #[cfg(feature = "debug_refcell")] + // borrowed_at: Cell::new(None), + _not_sync: PhantomData, + }) + } + + /// Consumes the `RefCell`, returning the wrapped value. + /// + /// # Examples + /// + /// ``` + /// use std::cell::RefCell; + /// + /// let c = RefCell::new(5); + /// + /// let five = c.into_inner(); + /// ``` + #[inline] + pub fn into_inner(self) -> T { + // Since this function takes `self` (the `RefCell`) by value, the + // compiler statically verifies that it is not currently borrowed. + self.value.into_inner() + } + + /// Replaces the wrapped value with a new one, returning the old value, + /// without deinitializing either one. + /// + /// This function corresponds to [`std::mem::replace`](../mem/fn.replace.html). + /// + /// # Panics + /// + /// Panics if the value is currently borrowed. + /// + /// # Examples + /// + /// ``` + /// use std::cell::RefCell; + /// let cell = RefCell::new(5); + /// let old_value = cell.replace(6); + /// assert_eq!(old_value, 5); + /// assert_eq!(cell, RefCell::new(6)); + /// ``` + #[inline] + #[track_caller] + pub fn replace(&self, t: T) -> T { + mem::replace(&mut *self.borrow_mut(), t) + } + + /// Replaces the wrapped value with a new one computed from `f`, returning + /// the old value, without deinitializing either one. + /// + /// # Panics + /// + /// Panics if the value is currently borrowed. + /// + /// # Examples + /// + /// ``` + /// use std::cell::RefCell; + /// let cell = RefCell::new(5); + /// let old_value = cell.replace_with(|&mut old| old + 1); + /// assert_eq!(old_value, 5); + /// assert_eq!(cell, RefCell::new(6)); + /// ``` + #[inline] + #[track_caller] + pub fn replace_with T>(&self, f: F) -> T { + let mut_borrow = &mut *self.borrow_mut(); + let replacement = f(mut_borrow); + mem::replace(mut_borrow, replacement) + } + + /// Swaps the wrapped value of `self` with the wrapped value of `other`, + /// without deinitializing either one. + /// + /// This function corresponds to [`std::mem::swap`](../mem/fn.swap.html). + /// + /// # Panics + /// + /// Panics if the value in either `RefCell` is currently borrowed, or + /// if `self` and `other` point to the same `RefCell`. + /// + /// # Examples + /// + /// ``` + /// use std::cell::RefCell; + /// let c = RefCell::new(5); + /// let d = RefCell::new(6); + /// c.swap(&d); + /// assert_eq!(c, RefCell::new(6)); + /// assert_eq!(d, RefCell::new(5)); + /// ``` + #[inline] + pub fn swap(&self, other: &Self) { + mem::swap(&mut *self.borrow_mut(), &mut *other.borrow_mut()) + } +} + +impl RefCell { + /// Immutably borrows the wrapped value. + /// + /// The borrow lasts until the returned `Ref` exits scope. Multiple + /// immutable borrows can be taken out at the same time. + /// + /// # Panics + /// + /// Panics if the value is currently mutably borrowed. For a non-panicking variant, use + /// [`try_borrow`](#method.try_borrow). + /// + /// # Examples + /// + /// ``` + /// use std::cell::RefCell; + /// + /// let c = RefCell::new(5); + /// + /// let borrowed_five = c.borrow(); + /// let borrowed_five2 = c.borrow(); + /// ``` + /// + /// An example of panic: + /// + /// ```should_panic + /// use std::cell::RefCell; + /// + /// let c = RefCell::new(5); + /// + /// let m = c.borrow_mut(); + /// let b = c.borrow(); // this causes a panic + /// ``` + #[inline] + #[track_caller] + pub fn borrow(&self) -> Ref<'_, T> { + match self.try_borrow() { + Ok(b) => b, + Err(err) => panic_already_mutably_borrowed(err), + } + } + + /// Immutably borrows the wrapped value, returning an error if the value is currently mutably + /// borrowed. + /// + /// The borrow lasts until the returned `Ref` exits scope. Multiple immutable borrows can be + /// taken out at the same time. + /// + /// This is the non-panicking variant of [`borrow`](#method.borrow). + /// + /// # Examples + /// + /// ``` + /// use std::cell::RefCell; + /// + /// let c = RefCell::new(5); + /// + /// { + /// let m = c.borrow_mut(); + /// assert!(c.try_borrow().is_err()); + /// } + /// + /// { + /// let m = c.borrow(); + /// assert!(c.try_borrow().is_ok()); + /// } + /// ``` + #[inline] + #[cfg_attr(feature = "debug_refcell", track_caller)] + pub fn try_borrow(&self) -> Result, BorrowError> { + match BorrowRef::new(&self.borrow) { + Some(b) => { + #[cfg(feature = "debug_refcell")] + { + // `borrowed_at` is always the *first* active borrow + if b.borrow.get() == 1 { + self.borrowed_at.set(Some(crate::panic::Location::caller())); + } + } + + // SAFETY: `BorrowRef` ensures that there is only immutable access + // to the value while borrowed. + let value = unsafe { NonNull::new_unchecked(self.value.get()) }; + Ok(Ref { value, borrow: b }) + } + None => Err(BorrowError { + // If a borrow occurred, then we must already have an outstanding borrow, + // so `borrowed_at` will be `Some` + #[cfg(feature = "debug_refcell")] + location: self.borrowed_at.get().unwrap(), + }), + } + } + + /// Mutably borrows the wrapped value. + /// + /// The borrow lasts until the returned `RefMut` or all `RefMut`s derived + /// from it exit scope. The value cannot be borrowed while this borrow is + /// active. + /// + /// # Panics + /// + /// Panics if the value is currently borrowed. For a non-panicking variant, use + /// [`try_borrow_mut`](#method.try_borrow_mut). + /// + /// # Examples + /// + /// ``` + /// use std::cell::RefCell; + /// + /// let c = RefCell::new("hello".to_owned()); + /// + /// *c.borrow_mut() = "bonjour".to_owned(); + /// + /// assert_eq!(&*c.borrow(), "bonjour"); + /// ``` + /// + /// An example of panic: + /// + /// ```should_panic + /// use std::cell::RefCell; + /// + /// let c = RefCell::new(5); + /// let m = c.borrow(); + /// + /// let b = c.borrow_mut(); // this causes a panic + /// ``` + #[inline] + #[track_caller] + pub fn borrow_mut(&self) -> RefMut<'_, T> { + match self.try_borrow_mut() { + Ok(b) => b, + Err(err) => panic_already_borrowed(err), + } + } + + /// Mutably borrows the wrapped value, returning an error if the value is currently borrowed. + /// + /// The borrow lasts until the returned `RefMut` or all `RefMut`s derived + /// from it exit scope. The value cannot be borrowed while this borrow is + /// active. + /// + /// This is the non-panicking variant of [`borrow_mut`](#method.borrow_mut). + /// + /// # Examples + /// + /// ``` + /// use std::cell::RefCell; + /// + /// let c = RefCell::new(5); + /// + /// { + /// let m = c.borrow(); + /// assert!(c.try_borrow_mut().is_err()); + /// } + /// + /// assert!(c.try_borrow_mut().is_ok()); + /// ``` + #[inline] + #[cfg_attr(feature = "debug_refcell", track_caller)] + pub fn try_borrow_mut(&self) -> Result, BorrowMutError> { + match BorrowRefMut::new(&self.borrow) { + Some(b) => { + #[cfg(feature = "debug_refcell")] + { + self.borrowed_at.set(Some(crate::panic::Location::caller())); + } + + // SAFETY: `BorrowRefMut` guarantees unique access. + let value = unsafe { NonNull::new_unchecked(self.value.get()) }; + Ok(RefMut { + value, + borrow: b, + marker: PhantomData, + }) + } + None => Err(BorrowMutError { + // If a borrow occurred, then we must already have an outstanding borrow, + // so `borrowed_at` will be `Some` + #[cfg(feature = "debug_refcell")] + location: self.borrowed_at.get().unwrap(), + }), + } + } + + /// Returns a raw pointer to the underlying data in this cell. + /// + /// # Examples + /// + /// ``` + /// use std::cell::RefCell; + /// + /// let c = RefCell::new(5); + /// + /// let ptr = c.as_ptr(); + /// ``` + #[inline] + //#[rustc_never_returns_null_ptr] + pub fn as_ptr(&self) -> *mut T { + self.value.get() + } + + /// Returns a mutable reference to the underlying data. + /// + /// Since this method borrows `RefCell` mutably, it is statically guaranteed + /// that no borrows to the underlying data exist. The dynamic checks inherent + /// in [`borrow_mut`] and most other methods of `RefCell` are therefore + /// unnecessary. + /// + /// This method can only be called if `RefCell` can be mutably borrowed, + /// which in general is only the case directly after the `RefCell` has + /// been created. In these situations, skipping the aforementioned dynamic + /// borrowing checks may yield better ergonomics and runtime-performance. + /// + /// In most situations where `RefCell` is used, it can't be borrowed mutably. + /// Use [`borrow_mut`] to get mutable access to the underlying data then. + /// + /// [`borrow_mut`]: RefCell::borrow_mut() + /// + /// # Examples + /// + /// ``` + /// use std::cell::RefCell; + /// + /// let mut c = RefCell::new(5); + /// *c.get_mut() += 1; + /// + /// assert_eq!(c, RefCell::new(6)); + /// ``` + #[inline] + pub fn get_mut(&mut self) -> &mut T { + self.value.get_mut() + } + + /// Immutably borrows the wrapped value, returning an error if the value is + /// currently mutably borrowed. + /// + /// # Safety + /// + /// Unlike `RefCell::borrow`, this method is unsafe because it does not + /// return a `Ref`, thus leaving the borrow flag untouched. Mutably + /// borrowing the `RefCell` while the reference returned by this method + /// is alive is undefined behaviour. + /// + /// # Examples + /// + /// ``` + /// use std::cell::RefCell; + /// + /// let c = RefCell::new(5); + /// + /// { + /// let m = c.borrow_mut(); + /// assert!(unsafe { c.try_borrow_unguarded() }.is_err()); + /// } + /// + /// { + /// let m = c.borrow(); + /// assert!(unsafe { c.try_borrow_unguarded() }.is_ok()); + /// } + /// ``` + #[inline] + pub unsafe fn try_borrow_unguarded(&self) -> Result<&T, BorrowError> { + if !is_writing(self.borrow.get()) { + // SAFETY: We check that nobody is actively writing now, but it is + // the caller's responsibility to ensure that nobody writes until + // the returned reference is no longer in use. + // Also, `self.value.get()` refers to the value owned by `self` + // and is thus guaranteed to be valid for the lifetime of `self`. + Ok(unsafe { &*self.value.get() }) + } else { + Err(BorrowError { + // If a borrow occurred, then we must already have an outstanding borrow, + // so `borrowed_at` will be `Some` + #[cfg(feature = "debug_refcell")] + location: self.borrowed_at.get().unwrap(), + }) + } + } +} + +impl RefCell { + /// Takes the wrapped value, leaving `Default::default()` in its place. + /// + /// # Panics + /// + /// Panics if the value is currently borrowed. + /// + /// # Examples + /// + /// ``` + /// use std::cell::RefCell; + /// + /// let c = RefCell::new(5); + /// let five = c.take(); + /// + /// assert_eq!(five, 5); + /// assert_eq!(c.into_inner(), 0); + /// ``` + pub fn take(&self) -> T { + self.replace(Default::default()) + } +} + +unsafe impl Send for RefCell where T: Send {} + +impl Clone for RefCell { + /// # Panics + /// + /// Panics if the value is currently mutably borrowed. + #[inline] + #[track_caller] + fn clone(&self) -> RefCell { + RefCell::new(self.borrow().clone()) + } + + /// # Panics + /// + /// Panics if `other` is currently mutably borrowed. + #[inline] + #[track_caller] + fn clone_from(&mut self, other: &Self) { + self.get_mut().clone_from(&other.borrow()) + } +} + +impl Default for RefCell { + /// Creates a `RefCell`, with the `Default` value for T. + #[inline] + fn default() -> RefCell { + RefCell::new(Default::default()) + } +} + +impl PartialEq for RefCell { + /// # Panics + /// + /// Panics if the value in either `RefCell` is currently mutably borrowed. + #[inline] + fn eq(&self, other: &RefCell) -> bool { + *self.borrow() == *other.borrow() + } +} + +impl Eq for RefCell {} + +impl PartialOrd for RefCell { + /// # Panics + /// + /// Panics if the value in either `RefCell` is currently mutably borrowed. + #[inline] + fn partial_cmp(&self, other: &RefCell) -> Option { + self.borrow().partial_cmp(&*other.borrow()) + } + + /// # Panics + /// + /// Panics if the value in either `RefCell` is currently mutably borrowed. + #[inline] + fn lt(&self, other: &RefCell) -> bool { + *self.borrow() < *other.borrow() + } + + /// # Panics + /// + /// Panics if the value in either `RefCell` is currently mutably borrowed. + #[inline] + fn le(&self, other: &RefCell) -> bool { + *self.borrow() <= *other.borrow() + } + + /// # Panics + /// + /// Panics if the value in either `RefCell` is currently mutably borrowed. + #[inline] + fn gt(&self, other: &RefCell) -> bool { + *self.borrow() > *other.borrow() + } + + /// # Panics + /// + /// Panics if the value in either `RefCell` is currently mutably borrowed. + #[inline] + fn ge(&self, other: &RefCell) -> bool { + *self.borrow() >= *other.borrow() + } +} + +impl Ord for RefCell { + /// # Panics + /// + /// Panics if the value in either `RefCell` is currently mutably borrowed. + #[inline] + fn cmp(&self, other: &RefCell) -> Ordering { + self.borrow().cmp(&*other.borrow()) + } +} + +impl From for RefCell { + /// Creates a new `RefCell` containing the given value. + fn from(t: T) -> RefCell { + RefCell::new(t) + } +} + +struct BorrowRef<'b> { + borrow: &'b Cell, +} + +impl<'b> BorrowRef<'b> { + #[inline] + fn new(borrow: &'b Cell) -> Option> { + let b = borrow.get().wrapping_add(1); + if !is_reading(b) { + // Incrementing borrow can result in a non-reading value (<= 0) in these cases: + // 1. It was < 0, i.e. there are writing borrows, so we can't allow a read borrow + // due to Rust's reference aliasing rules + // 2. It was isize::MAX (the max amount of reading borrows) and it overflowed + // into isize::MIN (the max amount of writing borrows) so we can't allow + // an additional read borrow because isize can't represent so many read borrows + // (this can only happen if you mem::forget more than a small constant amount of + // `Ref`s, which is not good practice) + None + } else { + // Incrementing borrow can result in a reading value (> 0) in these cases: + // 1. It was = 0, i.e. it wasn't borrowed, and we are taking the first read borrow + // 2. It was > 0 and < isize::MAX, i.e. there were read borrows, and isize + // is large enough to represent having one more read borrow + borrow.set(b); + Some(BorrowRef { borrow }) + } + } +} + +impl Drop for BorrowRef<'_> { + #[inline] + fn drop(&mut self) { + let borrow = self.borrow.get(); + debug_assert!(is_reading(borrow)); + self.borrow.set(borrow - 1); + } +} + +impl Clone for BorrowRef<'_> { + #[inline] + fn clone(&self) -> Self { + // Since this Ref exists, we know the borrow flag + // is a reading borrow. + let borrow = self.borrow.get(); + debug_assert!(is_reading(borrow)); + // Prevent the borrow counter from overflowing into + // a writing borrow. + assert!(borrow != BorrowFlag::MAX); + self.borrow.set(borrow + 1); + BorrowRef { + borrow: self.borrow, + } + } +} + +/// Wraps a borrowed reference to a value in a `RefCell` box. +/// A wrapper type for an immutably borrowed value from a `RefCell`. +/// +/// See the [module-level documentation](self) for more. +// #[must_not_suspend = "holding a Ref across suspend points can cause BorrowErrors"] +// #[rustc_diagnostic_item = "RefCellRef"] +pub struct Ref<'b, T: ?Sized + 'b> { + // NB: we use a pointer instead of `&'b T` to avoid `noalias` violations, because a + // `Ref` argument doesn't hold immutability for its whole scope, only until it drops. + // `NonNull` is also covariant over `T`, just like we would have with `&T`. + value: NonNull, + borrow: BorrowRef<'b>, +} + +impl Deref for Ref<'_, T> { + type Target = T; + + #[inline] + fn deref(&self) -> &T { + // SAFETY: the value is accessible as long as we hold our borrow. + unsafe { self.value.as_ref() } + } +} + +impl<'b, T: ?Sized> Ref<'b, T> { + /// Copies a `Ref`. + /// + /// The `RefCell` is already immutably borrowed, so this cannot fail. + /// + /// This is an associated function that needs to be used as + /// `Ref::clone(...)`. A `Clone` implementation or a method would interfere + /// with the widespread use of `r.borrow().clone()` to clone the contents of + /// a `RefCell`. + #[must_use] + #[inline] + pub fn clone(orig: &Ref<'b, T>) -> Ref<'b, T> { + Ref { + value: orig.value, + borrow: orig.borrow.clone(), + } + } + + /// Makes a new `Ref` for a component of the borrowed data. + /// + /// The `RefCell` is already immutably borrowed, so this cannot fail. + /// + /// This is an associated function that needs to be used as `Ref::map(...)`. + /// A method would interfere with methods of the same name on the contents + /// of a `RefCell` used through `Deref`. + /// + /// # Examples + /// + /// ``` + /// use std::cell::{RefCell, Ref}; + /// + /// let c = RefCell::new((5, 'b')); + /// let b1: Ref<'_, (u32, char)> = c.borrow(); + /// let b2: Ref<'_, u32> = Ref::map(b1, |t| &t.0); + /// assert_eq!(*b2, 5) + /// ``` + #[inline] + pub fn map(orig: Ref<'b, T>, f: F) -> Ref<'b, U> + where + F: FnOnce(&T) -> &U, + { + Ref { + value: NonNull::from(f(&*orig)), + borrow: orig.borrow, + } + } + + /// Makes a new `Ref` for an optional component of the borrowed data. The + /// original guard is returned as an `Err(..)` if the closure returns + /// `None`. + /// + /// The `RefCell` is already immutably borrowed, so this cannot fail. + /// + /// This is an associated function that needs to be used as + /// `Ref::filter_map(...)`. A method would interfere with methods of the same + /// name on the contents of a `RefCell` used through `Deref`. + /// + /// # Examples + /// + /// ``` + /// use std::cell::{RefCell, Ref}; + /// + /// let c = RefCell::new(vec![1, 2, 3]); + /// let b1: Ref<'_, Vec> = c.borrow(); + /// let b2: Result, _> = Ref::filter_map(b1, |v| v.get(1)); + /// assert_eq!(*b2.unwrap(), 2); + /// ``` + #[inline] + pub fn filter_map(orig: Ref<'b, T>, f: F) -> Result, Self> + where + F: FnOnce(&T) -> Option<&U>, + { + match f(&*orig) { + Some(value) => Ok(Ref { + value: NonNull::from(value), + borrow: orig.borrow, + }), + None => Err(orig), + } + } + + /// Splits a `Ref` into multiple `Ref`s for different components of the + /// borrowed data. + /// + /// The `RefCell` is already immutably borrowed, so this cannot fail. + /// + /// This is an associated function that needs to be used as + /// `Ref::map_split(...)`. A method would interfere with methods of the same + /// name on the contents of a `RefCell` used through `Deref`. + /// + /// # Examples + /// + /// ``` + /// use std::cell::{Ref, RefCell}; + /// + /// let cell = RefCell::new([1, 2, 3, 4]); + /// let borrow = cell.borrow(); + /// let (begin, end) = Ref::map_split(borrow, |slice| slice.split_at(2)); + /// assert_eq!(*begin, [1, 2]); + /// assert_eq!(*end, [3, 4]); + /// ``` + #[inline] + pub fn map_split(orig: Ref<'b, T>, f: F) -> (Ref<'b, U>, Ref<'b, V>) + where + F: FnOnce(&T) -> (&U, &V), + { + let (a, b) = f(&*orig); + let borrow = orig.borrow.clone(); + ( + Ref { + value: NonNull::from(a), + borrow, + }, + Ref { + value: NonNull::from(b), + borrow: orig.borrow, + }, + ) + } +} + +impl fmt::Display for Ref<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} + +impl<'b, T: ?Sized> RefMut<'b, T> { + /// Makes a new `RefMut` for a component of the borrowed data, e.g., an enum + /// variant. + /// + /// The `RefCell` is already mutably borrowed, so this cannot fail. + /// + /// This is an associated function that needs to be used as + /// `RefMut::map(...)`. A method would interfere with methods of the same + /// name on the contents of a `RefCell` used through `Deref`. + /// + /// # Examples + /// + /// ``` + /// use std::cell::{RefCell, RefMut}; + /// + /// let c = RefCell::new((5, 'b')); + /// { + /// let b1: RefMut<'_, (u32, char)> = c.borrow_mut(); + /// let mut b2: RefMut<'_, u32> = RefMut::map(b1, |t| &mut t.0); + /// assert_eq!(*b2, 5); + /// *b2 = 42; + /// } + /// assert_eq!(*c.borrow(), (42, 'b')); + /// ``` + #[inline] + pub fn map(mut orig: RefMut<'b, T>, f: F) -> RefMut<'b, U> + where + F: FnOnce(&mut T) -> &mut U, + { + let value = NonNull::from(f(&mut *orig)); + RefMut { + value, + borrow: orig.borrow, + marker: PhantomData, + } + } + + /// Makes a new `RefMut` for an optional component of the borrowed data. The + /// original guard is returned as an `Err(..)` if the closure returns + /// `None`. + /// + /// The `RefCell` is already mutably borrowed, so this cannot fail. + /// + /// This is an associated function that needs to be used as + /// `RefMut::filter_map(...)`. A method would interfere with methods of the + /// same name on the contents of a `RefCell` used through `Deref`. + /// + /// # Examples + /// + /// ``` + /// use std::cell::{RefCell, RefMut}; + /// + /// let c = RefCell::new(vec![1, 2, 3]); + /// + /// { + /// let b1: RefMut<'_, Vec> = c.borrow_mut(); + /// let mut b2: Result, _> = RefMut::filter_map(b1, |v| v.get_mut(1)); + /// + /// if let Ok(mut b2) = b2 { + /// *b2 += 2; + /// } + /// } + /// + /// assert_eq!(*c.borrow(), vec![1, 4, 3]); + /// ``` + #[inline] + pub fn filter_map(mut orig: RefMut<'b, T>, f: F) -> Result, Self> + where + F: FnOnce(&mut T) -> Option<&mut U>, + { + // SAFETY: function holds onto an exclusive reference for the duration + // of its call through `orig`, and the pointer is only de-referenced + // inside of the function call never allowing the exclusive reference to + // escape. + match f(&mut *orig) { + Some(value) => Ok(RefMut { + value: NonNull::from(value), + borrow: orig.borrow, + marker: PhantomData, + }), + None => Err(orig), + } + } + + /// Splits a `RefMut` into multiple `RefMut`s for different components of the + /// borrowed data. + /// + /// The underlying `RefCell` will remain mutably borrowed until both + /// returned `RefMut`s go out of scope. + /// + /// The `RefCell` is already mutably borrowed, so this cannot fail. + /// + /// This is an associated function that needs to be used as + /// `RefMut::map_split(...)`. A method would interfere with methods of the + /// same name on the contents of a `RefCell` used through `Deref`. + /// + /// # Examples + /// + /// ``` + /// use std::cell::{RefCell, RefMut}; + /// + /// let cell = RefCell::new([1, 2, 3, 4]); + /// let borrow = cell.borrow_mut(); + /// let (mut begin, mut end) = RefMut::map_split(borrow, |slice| slice.split_at_mut(2)); + /// assert_eq!(*begin, [1, 2]); + /// assert_eq!(*end, [3, 4]); + /// begin.copy_from_slice(&[4, 3]); + /// end.copy_from_slice(&[2, 1]); + /// ``` + #[inline] + pub fn map_split( + mut orig: RefMut<'b, T>, + f: F, + ) -> (RefMut<'b, U>, RefMut<'b, V>) + where + F: FnOnce(&mut T) -> (&mut U, &mut V), + { + let borrow = orig.borrow.clone(); + let (a, b) = f(&mut *orig); + ( + RefMut { + value: NonNull::from(a), + borrow, + marker: PhantomData, + }, + RefMut { + value: NonNull::from(b), + borrow: orig.borrow, + marker: PhantomData, + }, + ) + } +} + +struct BorrowRefMut<'b> { + borrow: &'b Cell, +} + +impl Drop for BorrowRefMut<'_> { + #[inline] + fn drop(&mut self) { + let borrow = self.borrow.get(); + debug_assert!(is_writing(borrow)); + self.borrow.set(borrow + 1); + } +} + +impl<'b> BorrowRefMut<'b> { + #[inline] + fn new(borrow: &'b Cell) -> Option> { + // NOTE: Unlike BorrowRefMut::clone, new is called to create the initial + // mutable reference, and so there must currently be no existing + // references. Thus, while clone increments the mutable refcount, here + // we explicitly only allow going from UNUSED to UNUSED - 1. + match borrow.get() { + UNUSED => { + borrow.set(UNUSED - 1); + Some(BorrowRefMut { borrow }) + } + _ => None, + } + } + + // Clones a `BorrowRefMut`. + // + // This is only valid if each `BorrowRefMut` is used to track a mutable + // reference to a distinct, nonoverlapping range of the original object. + // This isn't in a Clone impl so that code doesn't call this implicitly. + #[inline] + fn clone(&self) -> BorrowRefMut<'b> { + let borrow = self.borrow.get(); + debug_assert!(is_writing(borrow)); + // Prevent the borrow counter from underflowing. + assert!(borrow != BorrowFlag::MIN); + self.borrow.set(borrow - 1); + BorrowRefMut { + borrow: self.borrow, + } + } +} + +/// A wrapper type for a mutably borrowed value from a `RefCell`. +/// +/// See the [module-level documentation](self) for more. +// #[must_not_suspend = "holding a RefMut across suspend points can cause BorrowErrors"] +// #[rustc_diagnostic_item = "RefCellRefMut"] +pub struct RefMut<'b, T: ?Sized + 'b> { + // NB: we use a pointer instead of `&'b mut T` to avoid `noalias` violations, because a + // `RefMut` argument doesn't hold exclusivity for its whole scope, only until it drops. + value: NonNull, + borrow: BorrowRefMut<'b>, + // `NonNull` is covariant over `T`, so we need to reintroduce invariance. + marker: PhantomData<&'b mut T>, +} + +impl Deref for RefMut<'_, T> { + type Target = T; + + #[inline] + fn deref(&self) -> &T { + // SAFETY: the value is accessible as long as we hold our borrow. + unsafe { self.value.as_ref() } + } +} + +impl DerefMut for RefMut<'_, T> { + #[inline] + fn deref_mut(&mut self) -> &mut T { + // SAFETY: the value is accessible as long as we hold our borrow. + unsafe { self.value.as_mut() } + } +} + +impl fmt::Display for RefMut<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} diff --git a/rs-matter/src/utils/epoch.rs b/rs-matter/src/utils/epoch.rs index 630f8783..17ad4a2b 100644 --- a/rs-matter/src/utils/epoch.rs +++ b/rs-matter/src/utils/epoch.rs @@ -1,3 +1,20 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + use core::time::Duration; pub type Epoch = fn() -> Duration; diff --git a/rs-matter/src/utils/init.rs b/rs-matter/src/utils/init.rs new file mode 100644 index 00000000..e3bedfc1 --- /dev/null +++ b/rs-matter/src/utils/init.rs @@ -0,0 +1,91 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use core::convert::Infallible; +use core::{cell::UnsafeCell, mem::MaybeUninit}; + +/// Re-export `pinned-init` because its API is very unstable currently (0.0.x) +pub use pinned_init::*; + +/// An extension trait for converting `Init` to a fallible `Init`. +/// Useful when chaining an infallible initializer with a fallible chained initialization function. +pub trait IntoFallibleInit: Init { + /// Convert the infallible initializer to a fallible one. + fn into_fallible(self) -> impl Init { + unsafe { + init_from_closure(move |slot| { + Self::__init(self, slot).unwrap(); + + Ok(()) + }) + } + } +} + +impl IntoFallibleInit for I where I: Init {} + +/// An extension trait for re-setting an already instantiated `T` with the given initializer. +pub trait ApplyInit: Init { + fn apply(self, to: &mut T) -> Result<(), E> { + unsafe { Self::__init(self, to as *mut T) } + } +} + +impl ApplyInit for I where I: Init {} +/// An extension trait for retrofitting `UnsafeCell` with an initializer. +pub trait UnsafeCellInit { + /// Create a new in-place initializer for `UnsafeCell` + /// by using the given initializer for the value. + fn init>(value: I) -> impl Init; +} + +impl UnsafeCellInit for UnsafeCell { + fn init>(value: I) -> impl Init { + unsafe { + init_from_closure::<_, Infallible>(move |slot: *mut Self| { + // `slot` contains uninit memory, avoid creating a reference. + let slot: *mut T = slot as _; + + // Initialize the value + value.__init(slot).unwrap(); + + Ok(()) + }) + } + } +} + +/// An extension trait that allows safe initialization of +/// `MaybeUninit` memory. +pub trait InitMaybeUninit { + /// Initialize Self with the given in-place initializer. + fn init_with>(&mut self, init: I) -> &mut T { + self.try_init_with(init).unwrap() + } + + fn try_init_with, E>(&mut self, init: I) -> Result<&mut T, E>; +} + +impl InitMaybeUninit for MaybeUninit { + fn try_init_with, E>(&mut self, init: I) -> Result<&mut T, E> { + unsafe { + Init::::__init(init, self.as_mut_ptr())?; + + Ok(self.assume_init_mut()) + } + } +} diff --git a/rs-matter/src/utils/maybe.rs b/rs-matter/src/utils/maybe.rs new file mode 100644 index 00000000..b79cf206 --- /dev/null +++ b/rs-matter/src/utils/maybe.rs @@ -0,0 +1,215 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use core::fmt::Debug; +use core::hash::Hash; +use core::marker::PhantomData; +use core::mem::MaybeUninit; +use core::ops::{Deref, DerefMut}; +use core::ptr::addr_of_mut; + +use super::init; + +/// Represents a type similar in spirit to the built-in `Option` type. +/// Unlike `Option` however, `Maybe` _does_ have in-place initializer support. +/// +/// (In-place initializer support is impossible to provide for `Option` due to its +/// enum nature, and because it is not marked with `repr(transparent)`). +/// +/// `Maybe` is convertable to and from `Option` (via the `From` / `Into` traits), +/// however these conversions are not recommended when the wrapped value is large +/// which defeats the purpose of using `Maybe` in the first place. +/// +/// The canonical way to use `Maybe` with large values is to initialize it in-place with +/// one of the provided init constructors, and then use one of the `as_ref`, `as_mut`, +/// `as_deref` and `as_deref_mut` methods to access the wrapped value. +#[derive(Debug)] +pub struct Maybe { + some: bool, + value: MaybeUninit, + _tag: PhantomData, +} + +impl Maybe { + /// Create a new `Maybe` value from an `Option`. + /// + /// Note that when the wrapped value is large, it is recommended instead to use + /// `Maybe::init_none()` and `Maybe::init_some()` to create the `Maybe` value in-place. + pub fn new(value: Option) -> Self { + match value { + Some(v) => Self::some(v), + None => Self::none(), + } + } + + /// Create a new, empty `Maybe` value. + pub const fn none() -> Self { + Self { + some: false, + value: MaybeUninit::uninit(), + _tag: PhantomData, + } + } + + /// Create a new `Maybe` value with a wrapped value. + pub const fn some(value: T) -> Self { + Self { + some: true, + value: MaybeUninit::new(value), + _tag: PhantomData, + } + } + + /// Create an in-place initializer for a `Maybe` value that is empty. + pub fn init_none, E>() -> impl init::Init { + Self::init::(None) + } + + /// Create an in-place initializer for a `Maybe` value that is not empty + /// by initializing the wrapped value with the provided initializer. + pub fn init_some, E>(value: I) -> impl init::Init { + Self::init(Some(value)) + } + + /// Create an in-place initializer for a `Maybe` value that might or might + /// not be empty. + pub fn init, E>(value: Option) -> impl init::Init { + unsafe { + init::init_from_closure(move |slot: *mut Self| { + addr_of_mut!((*slot).some).write(value.is_some()); + + if let Some(value) = value { + value.__init(addr_of_mut!((*slot).value) as _)?; + } + + Ok(()) + }) + } + } + + /// Return a mutable reference to the wrapped value, if it exists. + pub fn as_mut(&mut self) -> Option<&mut T> { + if self.some { + Some(unsafe { self.value.assume_init_mut() }) + } else { + None + } + } + + /// Return a reference to the wrapped value, if it exists. + pub fn as_ref(&self) -> Option<&T> { + if self.some { + Some(unsafe { self.value.assume_init_ref() }) + } else { + None + } + } + + /// Derefs the wrapped value, if it exists. + pub fn as_deref(&self) -> Option<&T::Target> + where + T: Deref, + { + match self.as_ref() { + Some(t) => Some(t.deref()), + None => None, + } + } + + /// Derefs mutably the wrapped value, if it exists. + pub fn as_deref_mut(&mut self) -> Option<&mut T::Target> + where + T: DerefMut, + { + match self.as_mut() { + Some(t) => Some(t.deref_mut()), + None => None, + } + } + + /// Consume the `Maybe` value and return the wrapped value, if it exists. + /// + /// Note that this method is not efficient when the wrapped value is large + /// (might result in big stack memory usage due to moves), hence its usage + /// is not recommended when the wrapped value is large. + pub fn into_option(self) -> Option { + if self.some { + Some(unsafe { self.value.assume_init() }) + } else { + None + } + } + + /// Return whether the `Maybe` value is empty. + pub fn is_none(&self) -> bool { + !self.some + } + + /// Return whether the `Maybe` value is not empty. + pub fn is_some(&self) -> bool { + self.some + } +} + +impl Default for Maybe { + fn default() -> Self { + Self::none() + } +} + +impl From> for Maybe { + fn from(value: Option) -> Self { + Self::new(value) + } +} + +impl From> for Option { + fn from(value: Maybe) -> Self { + value.into_option() + } +} + +impl Clone for Maybe +where + T: Clone, +{ + fn clone(&self) -> Self { + Maybe::<_, G>::new(self.as_ref().cloned()) + } +} + +impl Copy for Maybe where T: Copy {} + +impl PartialEq for Maybe +where + T: PartialEq, +{ + fn eq(&self, other: &Self) -> bool { + self.as_ref() == other.as_ref() + } +} + +impl Eq for Maybe where T: Eq {} + +impl Hash for Maybe +where + T: Hash, +{ + fn hash(&self, state: &mut H) { + self.as_ref().hash(state) + } +} diff --git a/rs-matter/src/utils/mod.rs b/rs-matter/src/utils/mod.rs index b7c0136c..da766956 100644 --- a/rs-matter/src/utils/mod.rs +++ b/rs-matter/src/utils/mod.rs @@ -15,14 +15,11 @@ * limitations under the License. */ -pub mod buf; +pub mod cell; pub mod epoch; -pub mod ifmutex; -pub mod notification; -pub mod parsebuf; +pub mod init; +pub mod maybe; pub mod rand; -pub mod ringbuf; pub mod select; -pub mod signal; -pub mod std_mutex; -pub mod writebuf; +pub mod storage; +pub mod sync; diff --git a/rs-matter/src/utils/std_mutex.rs b/rs-matter/src/utils/storage.rs similarity index 50% rename from rs-matter/src/utils/std_mutex.rs rename to rs-matter/src/utils/storage.rs index 868e2339..a8c24904 100644 --- a/rs-matter/src/utils/std_mutex.rs +++ b/rs-matter/src/utils/storage.rs @@ -15,28 +15,14 @@ * limitations under the License. */ -#![cfg(feature = "std")] +pub use parsebuf::*; +pub use ringbuf::*; +pub use vec::*; +pub use writebuf::*; -use embassy_sync::blocking_mutex::raw::RawMutex; +pub mod pooled; -/// An `embassy-sync` `RawMutex` implementation using `std::sync::Mutex`. -/// TODO: Upstream into `embassy-sync` itself. -#[derive(Default)] -pub struct StdRawMutex(std::sync::Mutex<()>); - -impl StdRawMutex { - pub const fn new() -> Self { - Self(std::sync::Mutex::new(())) - } -} - -unsafe impl RawMutex for StdRawMutex { - #[allow(clippy::declare_interior_mutable_const)] - const INIT: Self = StdRawMutex(std::sync::Mutex::new(())); - - fn lock(&self, f: impl FnOnce() -> R) -> R { - let _guard = self.0.lock().unwrap(); - - f() - } -} +mod parsebuf; +mod ringbuf; +mod vec; +mod writebuf; diff --git a/rs-matter/src/utils/parsebuf.rs b/rs-matter/src/utils/storage/parsebuf.rs similarity index 99% rename from rs-matter/src/utils/parsebuf.rs rename to rs-matter/src/utils/storage/parsebuf.rs index d96c416f..3bc26685 100644 --- a/rs-matter/src/utils/parsebuf.rs +++ b/rs-matter/src/utils/storage/parsebuf.rs @@ -120,7 +120,7 @@ impl<'a> ParseBuf<'a> { #[cfg(test)] mod tests { - use crate::utils::parsebuf::*; + use crate::utils::storage::ParseBuf; #[test] fn test_parse_with_success() { diff --git a/rs-matter/src/utils/buf.rs b/rs-matter/src/utils/storage/pooled.rs similarity index 85% rename from rs-matter/src/utils/buf.rs rename to rs-matter/src/utils/storage/pooled.rs index 898f3bd5..a980902d 100644 --- a/rs-matter/src/utils/buf.rs +++ b/rs-matter/src/utils/storage/pooled.rs @@ -23,7 +23,8 @@ use embassy_futures::select::{select, Either}; use embassy_sync::blocking_mutex::raw::RawMutex; use embassy_time::{Duration, Timer}; -use super::signal::Signal; +use crate::utils::init::{init, Init, UnsafeCellInit}; +use crate::utils::sync::Signal; /// A trait for getting access to a `&mut T` buffer, potentially awaiting until a buffer becomes available. pub trait BufferAccess @@ -59,7 +60,7 @@ where /// Accessing a buffer would fail when all buffers are still used elsewhere after a wait timeout expires. pub struct PooledBuffers { available: Signal, - pool: UnsafeCell>, + pool: UnsafeCell>, wait_timeout_ms: u32, } @@ -67,14 +68,30 @@ impl PooledBuffers where M: RawMutex, { + /// Create a new instance of `PooledBuffers`. + /// + /// `wait_timneout_ms` is the maximum time to wait for a buffer to become available + /// before returning `None`. #[inline(always)] pub const fn new(wait_timeout_ms: u32) -> Self { Self { available: Signal::new([true; N]), - pool: UnsafeCell::new(heapless::Vec::new()), + pool: UnsafeCell::new(crate::utils::storage::Vec::new()), wait_timeout_ms, } } + + /// Create an in-place initializer for `PooledBuffers`. + /// + /// `wait_timneout_ms` is the maximum time to wait for a buffer to become available + /// before returning `None`. + pub fn init(wait_timeout_ms: u32) -> impl Init { + init!(Self { + available: Signal::new([true; N]), + pool <- UnsafeCell::init(crate::utils::storage::Vec::init()), + wait_timeout_ms, + }) + } } impl BufferAccess for PooledBuffers diff --git a/rs-matter/src/utils/ringbuf.rs b/rs-matter/src/utils/storage/ringbuf.rs similarity index 87% rename from rs-matter/src/utils/ringbuf.rs rename to rs-matter/src/utils/storage/ringbuf.rs index a02f8264..2cefeb81 100644 --- a/rs-matter/src/utils/ringbuf.rs +++ b/rs-matter/src/utils/storage/ringbuf.rs @@ -17,13 +17,15 @@ use core::cmp::min; +use crate::utils::init::{init, Init}; + /// A ring buffer of a fixed capacity `N` using owned storage. #[derive(Debug)] pub struct RingBuf { - buf: heapless::Vec, + buf: crate::utils::storage::Vec, start: usize, end: usize, - empty: bool, + non_empty: bool, } impl Default for RingBuf { @@ -37,13 +39,23 @@ impl RingBuf { #[inline(always)] pub const fn new() -> Self { Self { - buf: heapless::Vec::new(), + buf: crate::utils::storage::Vec::new(), start: 0, end: 0, - empty: true, + non_empty: false, } } + /// Create an in-place initializer for the ring buffer. + pub fn init() -> impl Init { + init!(Self { + buf <- crate::utils::storage::Vec::init(), + start: 0, + end: 0, + non_empty: false, + }) + } + /// Push new data to the end of the buffer. /// If the data does not fit in the buffer, the oldest data is dropped to make room for the new one. /// @@ -62,7 +74,7 @@ impl RingBuf { offset += len; - if !self.empty && self.start >= self.end && self.start < self.end + len { + if self.non_empty && self.start >= self.end && self.start < self.end + len { // Dropping oldest data self.start = self.end + len; } @@ -71,7 +83,7 @@ impl RingBuf { self.wrap(); - self.empty = false; + self.non_empty = true; } self.len() @@ -88,7 +100,7 @@ impl RingBuf { self.buf[self.end] = data; - if !self.empty && self.start == self.end { + if self.non_empty && self.start == self.end { // Dropping oldest data self.start = self.end + 1; } @@ -97,7 +109,7 @@ impl RingBuf { self.wrap(); - self.empty = false; + self.non_empty = true; self.len() } @@ -121,7 +133,7 @@ impl RingBuf { pub fn pop(&mut self, out_buf: &mut [u8]) -> usize { let mut offset = 0; - while offset < out_buf.len() && !self.empty { + while offset < out_buf.len() && self.non_empty { let len = min( if self.start < self.end { self.end @@ -138,7 +150,7 @@ impl RingBuf { self.wrap(); if self.start == self.end { - self.empty = true + self.non_empty = false } offset += len; @@ -150,20 +162,20 @@ impl RingBuf { /// Return `true` when the buffer is full. #[inline(always)] pub fn is_full(&self) -> bool { - self.start == self.end && !self.empty + self.start == self.end && self.non_empty } /// Return `true` when the buffer is empty. #[inline(always)] pub fn is_empty(&self) -> bool { - self.empty + !self.non_empty } /// Return the current size of the data in the buffer. #[inline(always)] #[allow(unused)] pub fn len(&self) -> usize { - if self.empty { + if !self.non_empty { 0 } else if self.start < self.end { self.end - self.start @@ -184,7 +196,7 @@ impl RingBuf { pub fn clear(&mut self) { self.start = 0; self.end = 0; - self.empty = true; + self.non_empty = false; } #[inline(always)] diff --git a/rs-matter/src/utils/storage/vec.rs b/rs-matter/src/utils/storage/vec.rs new file mode 100644 index 00000000..49e85d80 --- /dev/null +++ b/rs-matter/src/utils/storage/vec.rs @@ -0,0 +1,1691 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! A modification of `heapless::Vec` that provides the following extra features: +//! - In-place initialization of the vec itself with `Vec::init() -> impl Init` +//! - In-place initialization of the vec members with `Vec::push_init(init: I) -> Result<(), ()>` + +#![allow(clippy::unnecessary_cast)] +#![allow(clippy::redundant_slicing)] +#![allow(clippy::result_unit_err)] +#![allow(clippy::should_implement_trait)] + +use core::{ + cmp::Ordering, + fmt, hash, + iter::FromIterator, + mem::MaybeUninit, + ops, + ptr::{self, addr_of_mut}, + slice, +}; + +use crate::utils::init::{init_from_closure, Init}; + +/// A fixed capacity [`Vec`](https://doc.rust-lang.org/std/vec/struct.Vec.html) +/// +/// # Examples +/// +/// ``` +/// use heapless::Vec; +/// +/// +/// // A vector with a fixed capacity of 8 elements allocated on the stack +/// let mut vec = Vec::<_, 8>::new(); +/// vec.push(1); +/// vec.push(2); +/// +/// assert_eq!(vec.len(), 2); +/// assert_eq!(vec[0], 1); +/// +/// assert_eq!(vec.pop(), Some(2)); +/// assert_eq!(vec.len(), 1); +/// +/// vec[0] = 7; +/// assert_eq!(vec[0], 7); +/// +/// vec.extend([1, 2, 3].iter().cloned()); +/// +/// for x in &vec { +/// println!("{}", x); +/// } +/// assert_eq!(*vec, [7, 1, 2, 3]); +/// ``` +pub struct Vec { + // NOTE order is important for optimizations. the `len` first layout lets the compiler optimize + // `new` to: reserve stack space and zero the first word. With the fields in the reverse order + // the compiler optimizes `new` to `memclr`-ing the *entire* stack space, including the `buffer` + // field which should be left uninitialized. Optimizations were last checked with Rust 1.60 + len: usize, + + buffer: [MaybeUninit; N], +} + +impl Vec { + const ELEM: MaybeUninit = MaybeUninit::uninit(); + const INIT: [MaybeUninit; N] = [Self::ELEM; N]; // important for optimization of `new` + + /// Constructs a new, empty vector with a fixed capacity of `N` + /// + /// # Examples + /// + /// ``` + /// use heapless::Vec; + /// + /// // allocate the vector on the stack + /// let mut x: Vec = Vec::new(); + /// + /// // allocate the vector in a static variable + /// static mut X: Vec = Vec::new(); + /// ``` + /// `Vec` `const` constructor; wrap the returned value in [`Vec`]. + pub const fn new() -> Self { + Self { + len: 0, + buffer: Self::INIT, + } + } + + /// Returns an in-place initializer for a new, empty vector. + pub fn init() -> impl Init { + unsafe { + init_from_closure(move |slot: *mut Self| { + addr_of_mut!((*slot).len).write(0); + + Ok(()) + }) + } + } + + /// Constructs a new vector with a fixed capacity of `N` and fills it + /// with the provided slice. + /// + /// This is equivalent to the following code: + /// + /// ``` + /// use heapless::Vec; + /// + /// let mut v: Vec = Vec::new(); + /// v.extend_from_slice(&[1, 2, 3]).unwrap(); + /// ``` + #[inline] + pub fn from_slice(other: &[T]) -> Result + where + T: Clone, + { + let mut v = Vec::new(); + v.extend_from_slice(other)?; + Ok(v) + } + + /// Clones a vec into a new vec + pub(crate) fn clone(&self) -> Self + where + T: Clone, + { + let mut new = Self::new(); + // avoid `extend_from_slice` as that introduces a runtime check / panicking branch + for elem in self { + unsafe { + new.push_unchecked(elem.clone()); + } + } + new + } + + /// Returns a raw pointer to the vector’s buffer. + pub fn as_ptr(&self) -> *const T { + self.buffer.as_ptr() as *const T + } + + /// Returns a raw pointer to the vector’s buffer, which may be mutated through. + pub fn as_mut_ptr(&mut self) -> *mut T { + self.buffer.as_mut_ptr() as *mut T + } + + /// Extracts a slice containing the entire vector. + /// + /// Equivalent to `&s[..]`. + /// + /// # Examples + /// + /// ``` + /// use heapless::Vec; + /// let buffer: Vec = Vec::from_slice(&[1, 2, 3, 5, 8]).unwrap(); + /// assert_eq!(buffer.as_slice(), &[1, 2, 3, 5, 8]); + /// ``` + pub fn as_slice(&self) -> &[T] { + // NOTE(unsafe) avoid bound checks in the slicing operation + // &buffer[..self.len] + unsafe { slice::from_raw_parts(self.buffer.as_ptr() as *const T, self.len) } + } + + /// Returns the contents of the vector as an array of length `M` if the length + /// of the vector is exactly `M`, otherwise returns `Err(self)`. + /// + /// # Examples + /// + /// ``` + /// use heapless::Vec; + /// let buffer: Vec = Vec::from_slice(&[1, 2, 3, 5, 8]).unwrap(); + /// let array: [u8; 5] = buffer.into_array().unwrap(); + /// assert_eq!(array, [1, 2, 3, 5, 8]); + /// ``` + pub fn into_array(self) -> Result<[T; M], Self> { + if self.len() == M { + // This is how the unstable `MaybeUninit::array_assume_init` method does it + let array = unsafe { (&self.buffer as *const _ as *const [T; M]).read() }; + + // We don't want `self`'s destructor to be called because that would drop all the + // items in the array + core::mem::forget(self); + + Ok(array) + } else { + Err(self) + } + } + + /// Extracts a mutable slice containing the entire vector. + /// + /// Equivalent to `&mut s[..]`. + /// + /// # Examples + /// + /// ``` + /// use heapless::Vec; + /// let mut buffer: Vec = Vec::from_slice(&[1, 2, 3, 5, 8]).unwrap(); + /// buffer[0] = 9; + /// assert_eq!(buffer.as_slice(), &[9, 2, 3, 5, 8]); + /// ``` + pub fn as_mut_slice(&mut self) -> &mut [T] { + // NOTE(unsafe) avoid bound checks in the slicing operation + // &mut buffer[..self.len] + unsafe { slice::from_raw_parts_mut(self.buffer.as_mut_ptr() as *mut T, self.len) } + } + + /// Returns the maximum number of elements the vector can hold. + pub const fn capacity(&self) -> usize { + N + } + + /// Clears the vector, removing all values. + pub fn clear(&mut self) { + self.truncate(0); + } + + /// Extends the vec from an iterator. + /// + /// # Panic + /// + /// Panics if the vec cannot hold all elements of the iterator. + pub fn extend(&mut self, iter: I) + where + I: IntoIterator, + { + for elem in iter { + self.push(elem).ok().unwrap() + } + } + + /// Clones and appends all elements in a slice to the `Vec`. + /// + /// Iterates over the slice `other`, clones each element, and then appends + /// it to this `Vec`. The `other` vector is traversed in-order. + /// + /// # Examples + /// + /// ``` + /// use heapless::Vec; + /// + /// let mut vec = Vec::::new(); + /// vec.push(1).unwrap(); + /// vec.extend_from_slice(&[2, 3, 4]).unwrap(); + /// assert_eq!(*vec, [1, 2, 3, 4]); + /// ``` + pub fn extend_from_slice(&mut self, other: &[T]) -> Result<(), ()> + where + T: Clone, + { + if self.len + other.len() > self.capacity() { + // won't fit in the `Vec`; don't modify anything and return an error + Err(()) + } else { + for elem in other { + unsafe { + self.push_unchecked(elem.clone()); + } + } + Ok(()) + } + } + + /// Removes the last element from a vector and returns it, or `None` if it's empty + pub fn pop(&mut self) -> Option { + if self.len != 0 { + Some(unsafe { self.pop_unchecked() }) + } else { + None + } + } + + /// Appends an `item` to the back of the collection + /// + /// Returns back the `item` if the vector is full + pub fn push(&mut self, item: T) -> Result<(), T> { + if self.len < self.capacity() { + unsafe { self.push_unchecked(item) } + Ok(()) + } else { + Err(item) + } + } + + /// Appends an item with the provided item initializer - `init` + /// to the back of the collection + /// + /// Returns an error generated by `f` if the vector is full + pub fn push_init, E, F: FnOnce() -> E>( + &mut self, + init: I, + f: F, + ) -> Result<(), E> { + if self.len < self.capacity() { + self.push_init_unchecked(init) + } else { + Err(f()) + } + } + + /// Removes the last element from a vector and returns it + /// + /// # Safety + /// + /// This assumes the vec to have at least one element. + pub unsafe fn pop_unchecked(&mut self) -> T { + debug_assert!(!self.is_empty()); + + self.len -= 1; + (self.buffer.get_unchecked_mut(self.len).as_ptr() as *const T).read() + } + + /// Appends an `item` to the back of the collection + /// + /// # Safety + /// + /// This assumes the vec is not full. + pub unsafe fn push_unchecked(&mut self, item: T) { + // NOTE(ptr::write) the memory slot that we are about to write to is uninitialized. We + // use `ptr::write` to avoid running `T`'s destructor on the uninitialized memory + debug_assert!(!self.is_full()); + + *self.buffer.get_unchecked_mut(self.len) = MaybeUninit::new(item); + + self.len += 1; + } + + /// Appends an item with the provided item initializer - `init` + /// to the back of the collection + /// + /// Panics if the vec is full. + pub fn push_init_unchecked, E>(&mut self, init: I) -> Result<(), E> { + if self.is_full() { + panic!("Vec::push_init_unchecked: vec is full"); + } + + unsafe { + // NOTE(ptr::write) the memory slot that we are about to write to is uninitialized. We + // use `ptr::write` to avoid running `T`'s destructor on the uninitialized memory + let buffer: *mut T = self.buffer.as_mut_ptr().add(self.len) as _; + + init.__init(buffer)?; + } + + self.len += 1; + + Ok(()) + } + + /// Shortens the vector, keeping the first `len` elements and dropping the rest. + pub fn truncate(&mut self, len: usize) { + // This is safe because: + // + // * the slice passed to `drop_in_place` is valid; the `len > self.len` + // case avoids creating an invalid slice, and + // * the `len` of the vector is shrunk before calling `drop_in_place`, + // such that no value will be dropped twice in case `drop_in_place` + // were to panic once (if it panics twice, the program aborts). + unsafe { + // Note: It's intentional that this is `>` and not `>=`. + // Changing it to `>=` has negative performance + // implications in some cases. See rust-lang/rust#78884 for more. + if len > self.len { + return; + } + let remaining_len = self.len - len; + let s = ptr::slice_from_raw_parts_mut(self.as_mut_ptr().add(len), remaining_len); + self.len = len; + ptr::drop_in_place(s); + } + } + + /// Resizes the Vec in-place so that len is equal to new_len. + /// + /// If new_len is greater than len, the Vec is extended by the + /// difference, with each additional slot filled with value. If + /// new_len is less than len, the Vec is simply truncated. + /// + /// See also [`resize_default`](Self::resize_default). + pub fn resize(&mut self, new_len: usize, value: T) -> Result<(), ()> + where + T: Clone, + { + if new_len > self.capacity() { + return Err(()); + } + + if new_len > self.len { + while self.len < new_len { + self.push(value.clone()).ok(); + } + } else { + self.truncate(new_len); + } + + Ok(()) + } + + /// Resizes the `Vec` in-place so that `len` is equal to `new_len`. + /// + /// If `new_len` is greater than `len`, the `Vec` is extended by the + /// difference, with each additional slot filled with `Default::default()`. + /// If `new_len` is less than `len`, the `Vec` is simply truncated. + /// + /// See also [`resize`](Self::resize). + pub fn resize_default(&mut self, new_len: usize) -> Result<(), ()> + where + T: Clone + Default, + { + self.resize(new_len, T::default()) + } + + /// Forces the length of the vector to `new_len`. + /// + /// This is a low-level operation that maintains none of the normal + /// invariants of the type. Normally changing the length of a vector + /// is done using one of the safe operations instead, such as + /// [`truncate`], [`resize`], [`extend`], or [`clear`]. + /// + /// [`truncate`]: Self::truncate + /// [`resize`]: Self::resize + /// [`extend`]: core::iter::Extend + /// [`clear`]: Self::clear + /// + /// # Safety + /// + /// - `new_len` must be less than or equal to [`capacity()`]. + /// - The elements at `old_len..new_len` must be initialized. + /// + /// [`capacity()`]: Self::capacity + /// + /// # Examples + /// + /// This method can be useful for situations in which the vector + /// is serving as a buffer for other code, particularly over FFI: + /// + /// ```no_run + /// # #![allow(dead_code)] + /// use heapless::Vec; + /// + /// # // This is just a minimal skeleton for the doc example; + /// # // don't use this as a starting point for a real library. + /// # pub struct StreamWrapper { strm: *mut core::ffi::c_void } + /// # const Z_OK: i32 = 0; + /// # extern "C" { + /// # fn deflateGetDictionary( + /// # strm: *mut core::ffi::c_void, + /// # dictionary: *mut u8, + /// # dictLength: *mut usize, + /// # ) -> i32; + /// # } + /// # impl StreamWrapper { + /// pub fn get_dictionary(&self) -> Option> { + /// // Per the FFI method's docs, "32768 bytes is always enough". + /// let mut dict = Vec::new(); + /// let mut dict_length = 0; + /// // SAFETY: When `deflateGetDictionary` returns `Z_OK`, it holds that: + /// // 1. `dict_length` elements were initialized. + /// // 2. `dict_length` <= the capacity (32_768) + /// // which makes `set_len` safe to call. + /// unsafe { + /// // Make the FFI call... + /// let r = deflateGetDictionary(self.strm, dict.as_mut_ptr(), &mut dict_length); + /// if r == Z_OK { + /// // ...and update the length to what was initialized. + /// dict.set_len(dict_length); + /// Some(dict) + /// } else { + /// None + /// } + /// } + /// } + /// # } + /// ``` + /// + /// While the following example is sound, there is a memory leak since + /// the inner vectors were not freed prior to the `set_len` call: + /// + /// ``` + /// use core::iter::FromIterator; + /// use heapless::Vec; + /// + /// let mut vec = Vec::, 3>::from_iter( + /// [ + /// Vec::from_iter([1, 0, 0].iter().cloned()), + /// Vec::from_iter([0, 1, 0].iter().cloned()), + /// Vec::from_iter([0, 0, 1].iter().cloned()), + /// ] + /// .iter() + /// .cloned() + /// ); + /// // SAFETY: + /// // 1. `old_len..0` is empty so no elements need to be initialized. + /// // 2. `0 <= capacity` always holds whatever `capacity` is. + /// unsafe { + /// vec.set_len(0); + /// } + /// ``` + /// + /// Normally, here, one would use [`clear`] instead to correctly drop + /// the contents and thus not leak memory. + pub unsafe fn set_len(&mut self, new_len: usize) { + debug_assert!(new_len <= self.capacity()); + + self.len = new_len + } + + /// Removes an element from the vector and returns it. + /// + /// The removed element is replaced by the last element of the vector. + /// + /// This does not preserve ordering, but is O(1). + /// + /// # Panics + /// + /// Panics if `index` is out of bounds. + /// + /// # Examples + /// + /// ``` + /// use heapless::Vec; + ///// use heapless::consts::*; + /// + /// let mut v: Vec<_, 8> = Vec::new(); + /// v.push("foo").unwrap(); + /// v.push("bar").unwrap(); + /// v.push("baz").unwrap(); + /// v.push("qux").unwrap(); + /// + /// assert_eq!(v.swap_remove(1), "bar"); + /// assert_eq!(&*v, ["foo", "qux", "baz"]); + /// + /// assert_eq!(v.swap_remove(0), "foo"); + /// assert_eq!(&*v, ["baz", "qux"]); + /// ``` + pub fn swap_remove(&mut self, index: usize) -> T { + assert!(index < self.len); + unsafe { self.swap_remove_unchecked(index) } + } + + /// Removes an element from the vector and returns it. + /// + /// The removed element is replaced by the last element of the vector. + /// + /// This does not preserve ordering, but is O(1). + /// + /// # Safety + /// + /// Assumes `index` within bounds. + /// + /// # Examples + /// + /// ``` + /// use heapless::Vec; + /// + /// let mut v: Vec<_, 8> = Vec::new(); + /// v.push("foo").unwrap(); + /// v.push("bar").unwrap(); + /// v.push("baz").unwrap(); + /// v.push("qux").unwrap(); + /// + /// assert_eq!(unsafe { v.swap_remove_unchecked(1) }, "bar"); + /// assert_eq!(&*v, ["foo", "qux", "baz"]); + /// + /// assert_eq!(unsafe { v.swap_remove_unchecked(0) }, "foo"); + /// assert_eq!(&*v, ["baz", "qux"]); + /// ``` + pub unsafe fn swap_remove_unchecked(&mut self, index: usize) -> T { + let length = self.len(); + debug_assert!(index < length); + let value = ptr::read(self.as_ptr().add(index)); + let base_ptr = self.as_mut_ptr(); + ptr::copy(base_ptr.add(length - 1), base_ptr.add(index), 1); + self.len -= 1; + value + } + + /// Returns true if the vec is full + #[inline] + pub fn is_full(&self) -> bool { + self.len == self.capacity() + } + + /// Returns true if the vec is empty + #[inline] + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + /// Returns `true` if `needle` is a prefix of the Vec. + /// + /// Always returns `true` if `needle` is an empty slice. + /// + /// # Examples + /// + /// ``` + /// use heapless::Vec; + /// + /// let v: Vec<_, 8> = Vec::from_slice(b"abc").unwrap(); + /// assert_eq!(v.starts_with(b""), true); + /// assert_eq!(v.starts_with(b"ab"), true); + /// assert_eq!(v.starts_with(b"bc"), false); + /// ``` + #[inline] + pub fn starts_with(&self, needle: &[T]) -> bool + where + T: PartialEq, + { + let n = needle.len(); + self.len >= n && needle == &self[..n] + } + + /// Returns `true` if `needle` is a suffix of the Vec. + /// + /// Always returns `true` if `needle` is an empty slice. + /// + /// # Examples + /// + /// ``` + /// use heapless::Vec; + /// + /// let v: Vec<_, 8> = Vec::from_slice(b"abc").unwrap(); + /// assert_eq!(v.ends_with(b""), true); + /// assert_eq!(v.ends_with(b"ab"), false); + /// assert_eq!(v.ends_with(b"bc"), true); + /// ``` + #[inline] + pub fn ends_with(&self, needle: &[T]) -> bool + where + T: PartialEq, + { + let (v, n) = (self.len(), needle.len()); + v >= n && needle == &self[v - n..] + } + + /// Inserts an element at position `index` within the vector, shifting all + /// elements after it to the right. + /// + /// Returns back the `element` if the vector is full. + /// + /// # Panics + /// + /// Panics if `index > len`. + /// + /// # Examples + /// + /// ``` + /// use heapless::Vec; + /// + /// let mut vec: Vec<_, 8> = Vec::from_slice(&[1, 2, 3]).unwrap(); + /// vec.insert(1, 4); + /// assert_eq!(vec, [1, 4, 2, 3]); + /// vec.insert(4, 5); + /// assert_eq!(vec, [1, 4, 2, 3, 5]); + /// ``` + pub fn insert(&mut self, index: usize, element: T) -> Result<(), T> { + let len = self.len(); + if index > len { + panic!( + "insertion index (is {}) should be <= len (is {})", + index, len + ); + } + + // check there's space for the new element + if self.is_full() { + return Err(element); + } + + unsafe { + // infallible + // The spot to put the new value + { + let p = self.as_mut_ptr().add(index); + // Shift everything over to make space. (Duplicating the + // `index`th element into two consecutive places.) + ptr::copy(p, p.offset(1), len - index); + // Write it in, overwriting the first copy of the `index`th + // element. + ptr::write(p, element); + } + self.set_len(len + 1); + } + + Ok(()) + } + + /// Removes and returns the element at position `index` within the vector, + /// shifting all elements after it to the left. + /// + /// Note: Because this shifts over the remaining elements, it has a + /// worst-case performance of *O*(*n*). If you don't need the order of + /// elements to be preserved, use [`swap_remove`] instead. If you'd like to + /// remove elements from the beginning of the `Vec`, consider using + /// [`Deque::pop_front`] instead. + /// + /// [`swap_remove`]: Vec::swap_remove + /// [`Deque::pop_front`]: crate::Deque::pop_front + /// + /// # Panics + /// + /// Panics if `index` is out of bounds. + /// + /// # Examples + /// + /// ``` + /// use heapless::Vec; + /// + /// let mut v: Vec<_, 8> = Vec::from_slice(&[1, 2, 3]).unwrap(); + /// assert_eq!(v.remove(1), 2); + /// assert_eq!(v, [1, 3]); + /// ``` + pub fn remove(&mut self, index: usize) -> T { + let len = self.len(); + if index >= len { + panic!("removal index (is {}) should be < len (is {})", index, len); + } + unsafe { + // infallible + let ret; + { + // the place we are taking from. + let ptr = self.as_mut_ptr().add(index); + // copy it out, unsafely having a copy of the value on + // the stack and in the vector at the same time. + ret = ptr::read(ptr); + + // Shift everything down to fill in that spot. + ptr::copy(ptr.offset(1), ptr, len - index - 1); + } + self.set_len(len - 1); + ret + } + } + + /// Retains only the elements specified by the predicate. + /// + /// In other words, remove all elements `e` for which `f(&e)` returns `false`. + /// This method operates in place, visiting each element exactly once in the + /// original order, and preserves the order of the retained elements. + /// + /// # Examples + /// + /// ``` + /// use heapless::Vec; + /// + /// let mut vec: Vec<_, 8> = Vec::from_slice(&[1, 2, 3, 4]).unwrap(); + /// vec.retain(|&x| x % 2 == 0); + /// assert_eq!(vec, [2, 4]); + /// ``` + /// + /// Because the elements are visited exactly once in the original order, + /// external state may be used to decide which elements to keep. + /// + /// ``` + /// use heapless::Vec; + /// + /// let mut vec: Vec<_, 8> = Vec::from_slice(&[1, 2, 3, 4, 5]).unwrap(); + /// let keep = [false, true, true, false, true]; + /// let mut iter = keep.iter(); + /// vec.retain(|_| *iter.next().unwrap()); + /// assert_eq!(vec, [2, 3, 5]); + /// ``` + pub fn retain(&mut self, mut f: F) + where + F: FnMut(&T) -> bool, + { + self.retain_mut(|elem| f(elem)); + } + + /// Retains only the elements specified by the predicate, passing a mutable reference to it. + /// + /// In other words, remove all elements `e` such that `f(&mut e)` returns `false`. + /// This method operates in place, visiting each element exactly once in the + /// original order, and preserves the order of the retained elements. + /// + /// # Examples + /// + /// ``` + /// use heapless::Vec; + /// + /// let mut vec: Vec<_, 8> = Vec::from_slice(&[1, 2, 3, 4]).unwrap(); + /// vec.retain_mut(|x| if *x <= 3 { + /// *x += 1; + /// true + /// } else { + /// false + /// }); + /// assert_eq!(vec, [2, 3, 4]); + /// ``` + pub fn retain_mut(&mut self, mut f: F) + where + F: FnMut(&mut T) -> bool, + { + let original_len = self.len(); + // Avoid double drop if the drop guard is not executed, + // since we may make some holes during the process. + unsafe { self.set_len(0) }; + + // Vec: [Kept, Kept, Hole, Hole, Hole, Hole, Unchecked, Unchecked] + // |<- processed len ->| ^- next to check + // |<- deleted cnt ->| + // |<- original_len ->| + // Kept: Elements which predicate returns true on. + // Hole: Moved or dropped element slot. + // Unchecked: Unchecked valid elements. + // + // This drop guard will be invoked when predicate or `drop` of element panicked. + // It shifts unchecked elements to cover holes and `set_len` to the correct length. + // In cases when predicate and `drop` never panick, it will be optimized out. + struct BackshiftOnDrop<'a, T, const N: usize> { + v: &'a mut Vec, + processed_len: usize, + deleted_cnt: usize, + original_len: usize, + } + + impl Drop for BackshiftOnDrop<'_, T, N> { + fn drop(&mut self) { + if self.deleted_cnt > 0 { + // SAFETY: Trailing unchecked items must be valid since we never touch them. + unsafe { + ptr::copy( + self.v.as_ptr().add(self.processed_len), + self.v + .as_mut_ptr() + .add(self.processed_len - self.deleted_cnt), + self.original_len - self.processed_len, + ); + } + } + // SAFETY: After filling holes, all items are in contiguous memory. + unsafe { + self.v.set_len(self.original_len - self.deleted_cnt); + } + } + } + + let mut g = BackshiftOnDrop { + v: self, + processed_len: 0, + deleted_cnt: 0, + original_len, + }; + + fn process_loop( + original_len: usize, + f: &mut F, + g: &mut BackshiftOnDrop<'_, T, N>, + ) where + F: FnMut(&mut T) -> bool, + { + while g.processed_len != original_len { + let p = g.v.as_mut_ptr(); + // SAFETY: Unchecked element must be valid. + let cur = unsafe { &mut *p.add(g.processed_len) }; + if !f(cur) { + // Advance early to avoid double drop if `drop_in_place` panicked. + g.processed_len += 1; + g.deleted_cnt += 1; + // SAFETY: We never touch this element again after dropped. + unsafe { ptr::drop_in_place(cur) }; + // We already advanced the counter. + if DELETED { + continue; + } else { + break; + } + } + if DELETED { + // SAFETY: `deleted_cnt` > 0, so the hole slot must not overlap with current element. + // We use copy for move, and never touch this element again. + unsafe { + let hole_slot = p.add(g.processed_len - g.deleted_cnt); + ptr::copy_nonoverlapping(cur, hole_slot, 1); + } + } + g.processed_len += 1; + } + } + + // Stage 1: Nothing was deleted. + process_loop::(original_len, &mut f, &mut g); + + // Stage 2: Some elements were deleted. + process_loop::(original_len, &mut f, &mut g); + + // All item are processed. This can be optimized to `set_len` by LLVM. + drop(g); + } +} + +// Trait implementations + +impl Default for Vec { + fn default() -> Self { + Self::new() + } +} + +impl fmt::Debug for Vec +where + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + <[T] as fmt::Debug>::fmt(self, f) + } +} + +impl fmt::Write for Vec { + fn write_str(&mut self, s: &str) -> fmt::Result { + match self.extend_from_slice(s.as_bytes()) { + Ok(()) => Ok(()), + Err(_) => Err(fmt::Error), + } + } +} + +impl Drop for Vec { + fn drop(&mut self) { + // We drop each element used in the vector by turning into a &mut[T] + unsafe { + ptr::drop_in_place(self.as_mut_slice()); + } + } +} + +impl<'a, T: Clone, const N: usize> TryFrom<&'a [T]> for Vec { + type Error = (); + + fn try_from(slice: &'a [T]) -> Result { + Vec::from_slice(slice) + } +} + +impl Extend for Vec { + fn extend(&mut self, iter: I) + where + I: IntoIterator, + { + self.extend(iter) + } +} + +impl<'a, T, const N: usize> Extend<&'a T> for Vec +where + T: 'a + Copy, +{ + fn extend(&mut self, iter: I) + where + I: IntoIterator, + { + self.extend(iter.into_iter().cloned()) + } +} + +impl hash::Hash for Vec +where + T: core::hash::Hash, +{ + fn hash(&self, state: &mut H) { + <[T] as hash::Hash>::hash(self, state) + } +} + +impl<'a, T, const N: usize> IntoIterator for &'a Vec { + type Item = &'a T; + type IntoIter = slice::Iter<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a, T, const N: usize> IntoIterator for &'a mut Vec { + type Item = &'a mut T; + type IntoIter = slice::IterMut<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + self.iter_mut() + } +} + +impl FromIterator for Vec { + fn from_iter(iter: I) -> Self + where + I: IntoIterator, + { + let mut vec = Vec::new(); + for i in iter { + vec.push(i).ok().expect("Vec::from_iter overflow"); + } + vec + } +} + +/// An iterator that moves out of an [`Vec`][`Vec`]. +/// +/// This struct is created by calling the `into_iter` method on [`Vec`][`Vec`]. +pub struct IntoIter { + vec: Vec, + next: usize, +} + +impl Iterator for IntoIter { + type Item = T; + fn next(&mut self) -> Option { + if self.next < self.vec.len() { + let item = unsafe { + (self.vec.buffer.get_unchecked_mut(self.next).as_ptr() as *const T).read() + }; + self.next += 1; + Some(item) + } else { + None + } + } +} + +impl Clone for IntoIter +where + T: Clone, +{ + fn clone(&self) -> Self { + let mut vec = Vec::new(); + + if self.next < self.vec.len() { + let s = unsafe { + slice::from_raw_parts( + (self.vec.buffer.as_ptr() as *const T).add(self.next), + self.vec.len() - self.next, + ) + }; + vec.extend_from_slice(s).ok(); + } + + Self { vec, next: 0 } + } +} + +impl Drop for IntoIter { + fn drop(&mut self) { + unsafe { + // Drop all the elements that have not been moved out of vec + ptr::drop_in_place(&mut self.vec.as_mut_slice()[self.next..]); + // Prevent dropping of other elements + self.vec.len = 0; + } + } +} + +impl IntoIterator for Vec { + type Item = T; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + IntoIter { vec: self, next: 0 } + } +} + +impl PartialEq> for Vec +where + A: PartialEq, +{ + fn eq(&self, other: &Vec) -> bool { + <[A]>::eq(self, &**other) + } +} + +// Vec == [B] +impl PartialEq<[B]> for Vec +where + A: PartialEq, +{ + fn eq(&self, other: &[B]) -> bool { + <[A]>::eq(self, &other[..]) + } +} + +// [B] == Vec +impl PartialEq> for [B] +where + A: PartialEq, +{ + fn eq(&self, other: &Vec) -> bool { + <[A]>::eq(other, &self[..]) + } +} + +// Vec == &[B] +impl PartialEq<&[B]> for Vec +where + A: PartialEq, +{ + fn eq(&self, other: &&[B]) -> bool { + <[A]>::eq(self, &other[..]) + } +} + +// &[B] == Vec +impl PartialEq> for &[B] +where + A: PartialEq, +{ + fn eq(&self, other: &Vec) -> bool { + <[A]>::eq(other, &self[..]) + } +} + +// Vec == &mut [B] +impl PartialEq<&mut [B]> for Vec +where + A: PartialEq, +{ + fn eq(&self, other: &&mut [B]) -> bool { + <[A]>::eq(self, &other[..]) + } +} + +// &mut [B] == Vec +impl PartialEq> for &mut [B] +where + A: PartialEq, +{ + fn eq(&self, other: &Vec) -> bool { + <[A]>::eq(other, &self[..]) + } +} + +// Vec == [B; M] +// Equality does not require equal capacity +impl PartialEq<[B; M]> for Vec +where + A: PartialEq, +{ + fn eq(&self, other: &[B; M]) -> bool { + <[A]>::eq(self, &other[..]) + } +} + +// [B; M] == Vec +// Equality does not require equal capacity +impl PartialEq> for [B; M] +where + A: PartialEq, +{ + fn eq(&self, other: &Vec) -> bool { + <[A]>::eq(other, &self[..]) + } +} + +// Vec == &[B; M] +// Equality does not require equal capacity +impl PartialEq<&[B; M]> for Vec +where + A: PartialEq, +{ + fn eq(&self, other: &&[B; M]) -> bool { + <[A]>::eq(self, &other[..]) + } +} + +// &[B; M] == Vec +// Equality does not require equal capacity +impl PartialEq> for &[B; M] +where + A: PartialEq, +{ + fn eq(&self, other: &Vec) -> bool { + <[A]>::eq(other, &self[..]) + } +} + +// Implements Eq if underlying data is Eq +impl Eq for Vec where T: Eq {} + +impl PartialOrd> for Vec +where + T: PartialOrd, +{ + fn partial_cmp(&self, other: &Vec) -> Option { + PartialOrd::partial_cmp(&**self, &**other) + } +} + +impl Ord for Vec +where + T: Ord, +{ + #[inline] + fn cmp(&self, other: &Self) -> Ordering { + Ord::cmp(&**self, &**other) + } +} + +impl ops::Deref for Vec { + type Target = [T]; + + fn deref(&self) -> &[T] { + self.as_slice() + } +} + +impl ops::DerefMut for Vec { + fn deref_mut(&mut self) -> &mut [T] { + self.as_mut_slice() + } +} + +impl AsRef> for Vec { + #[inline] + fn as_ref(&self) -> &Self { + self + } +} + +impl AsMut> for Vec { + #[inline] + fn as_mut(&mut self) -> &mut Self { + self + } +} + +impl AsRef<[T]> for Vec { + #[inline] + fn as_ref(&self) -> &[T] { + self + } +} + +impl AsMut<[T]> for Vec { + #[inline] + fn as_mut(&mut self) -> &mut [T] { + self + } +} + +impl Clone for Vec +where + T: Clone, +{ + fn clone(&self) -> Self { + self.clone() + } +} + +#[cfg(test)] +mod tests { + use core::fmt::Write; + + use super::Vec; + + macro_rules! droppable { + () => { + static COUNT: core::sync::atomic::AtomicI32 = core::sync::atomic::AtomicI32::new(0); + + #[derive(Eq, Ord, PartialEq, PartialOrd)] + struct Droppable(i32); + impl Droppable { + fn new() -> Self { + COUNT.fetch_add(1, core::sync::atomic::Ordering::Relaxed); + Droppable(Self::count()) + } + + fn count() -> i32 { + COUNT.load(core::sync::atomic::Ordering::Relaxed) + } + } + impl Drop for Droppable { + fn drop(&mut self) { + COUNT.fetch_sub(1, core::sync::atomic::Ordering::Relaxed); + } + } + }; + } + + #[test] + fn static_new() { + static mut _V: Vec = Vec::new(); + } + + #[test] + fn stack_new() { + let mut _v: Vec = Vec::new(); + } + + #[test] + fn is_full_empty() { + let mut v: Vec = Vec::new(); + + assert!(v.is_empty()); + assert!(!v.is_full()); + + v.push(1).unwrap(); + assert!(!v.is_empty()); + assert!(!v.is_full()); + + v.push(1).unwrap(); + assert!(!v.is_empty()); + assert!(!v.is_full()); + + v.push(1).unwrap(); + assert!(!v.is_empty()); + assert!(!v.is_full()); + + v.push(1).unwrap(); + assert!(!v.is_empty()); + assert!(v.is_full()); + } + + #[test] + fn drop() { + droppable!(); + + { + let mut v: Vec = Vec::new(); + v.push(Droppable::new()).ok().unwrap(); + v.push(Droppable::new()).ok().unwrap(); + v.pop().unwrap(); + } + + assert_eq!(Droppable::count(), 0); + + { + let mut v: Vec = Vec::new(); + v.push(Droppable::new()).ok().unwrap(); + v.push(Droppable::new()).ok().unwrap(); + } + + assert_eq!(Droppable::count(), 0); + } + + #[test] + fn eq() { + let mut xs: Vec = Vec::new(); + let mut ys: Vec = Vec::new(); + + assert_eq!(xs, ys); + + xs.push(1).unwrap(); + ys.push(1).unwrap(); + + assert_eq!(xs, ys); + } + + #[test] + fn cmp() { + let mut xs: Vec = Vec::new(); + let mut ys: Vec = Vec::new(); + + assert_eq!(xs, ys); + + xs.push(1).unwrap(); + ys.push(2).unwrap(); + + assert!(xs < ys); + } + + #[test] + fn cmp_heterogenous_size() { + let mut xs: Vec = Vec::new(); + let mut ys: Vec = Vec::new(); + + assert_eq!(xs, ys); + + xs.push(1).unwrap(); + ys.push(2).unwrap(); + + assert!(xs < ys); + } + + #[test] + fn cmp_with_arrays_and_slices() { + let mut xs: Vec = Vec::new(); + xs.push(1).unwrap(); + + let array = [1]; + + assert_eq!(xs, array); + assert_eq!(array, xs); + + assert_eq!(xs, array.as_slice()); + assert_eq!(array.as_slice(), xs); + + assert_eq!(xs, &array); + assert_eq!(&array, xs); + + let longer_array = [1; 20]; + + assert_ne!(xs, longer_array); + assert_ne!(longer_array, xs); + } + + #[test] + fn full() { + let mut v: Vec = Vec::new(); + + v.push(0).unwrap(); + v.push(1).unwrap(); + v.push(2).unwrap(); + v.push(3).unwrap(); + + assert!(v.push(4).is_err()); + } + + #[test] + fn iter() { + let mut v: Vec = Vec::new(); + + v.push(0).unwrap(); + v.push(1).unwrap(); + v.push(2).unwrap(); + v.push(3).unwrap(); + + let mut items = v.iter(); + + assert_eq!(items.next(), Some(&0)); + assert_eq!(items.next(), Some(&1)); + assert_eq!(items.next(), Some(&2)); + assert_eq!(items.next(), Some(&3)); + assert_eq!(items.next(), None); + } + + #[test] + fn iter_mut() { + let mut v: Vec = Vec::new(); + + v.push(0).unwrap(); + v.push(1).unwrap(); + v.push(2).unwrap(); + v.push(3).unwrap(); + + let mut items = v.iter_mut(); + + assert_eq!(items.next(), Some(&mut 0)); + assert_eq!(items.next(), Some(&mut 1)); + assert_eq!(items.next(), Some(&mut 2)); + assert_eq!(items.next(), Some(&mut 3)); + assert_eq!(items.next(), None); + } + + #[test] + fn collect_from_iter() { + let slice = &[1, 2, 3]; + let vec: Vec = slice.iter().cloned().collect(); + assert_eq!(&vec, slice); + } + + #[test] + #[should_panic] + fn collect_from_iter_overfull() { + let slice = &[1, 2, 3]; + let _vec = slice.iter().cloned().collect::>(); + } + + #[test] + fn iter_move() { + let mut v: Vec = Vec::new(); + v.push(0).unwrap(); + v.push(1).unwrap(); + v.push(2).unwrap(); + v.push(3).unwrap(); + + let mut items = v.into_iter(); + + assert_eq!(items.next(), Some(0)); + assert_eq!(items.next(), Some(1)); + assert_eq!(items.next(), Some(2)); + assert_eq!(items.next(), Some(3)); + assert_eq!(items.next(), None); + } + + #[test] + fn iter_move_drop() { + droppable!(); + + { + let mut vec: Vec = Vec::new(); + vec.push(Droppable::new()).ok().unwrap(); + vec.push(Droppable::new()).ok().unwrap(); + let mut items = vec.into_iter(); + // Move all + let _ = items.next(); + let _ = items.next(); + } + + assert_eq!(Droppable::count(), 0); + + { + let mut vec: Vec = Vec::new(); + vec.push(Droppable::new()).ok().unwrap(); + vec.push(Droppable::new()).ok().unwrap(); + let _items = vec.into_iter(); + // Move none + } + + assert_eq!(Droppable::count(), 0); + + { + let mut vec: Vec = Vec::new(); + vec.push(Droppable::new()).ok().unwrap(); + vec.push(Droppable::new()).ok().unwrap(); + let mut items = vec.into_iter(); + let _ = items.next(); // Move partly + } + + assert_eq!(Droppable::count(), 0); + } + + #[test] + fn push_and_pop() { + let mut v: Vec = Vec::new(); + assert_eq!(v.len(), 0); + + assert_eq!(v.pop(), None); + assert_eq!(v.len(), 0); + + v.push(0).unwrap(); + assert_eq!(v.len(), 1); + + assert_eq!(v.pop(), Some(0)); + assert_eq!(v.len(), 0); + + assert_eq!(v.pop(), None); + assert_eq!(v.len(), 0); + } + + #[test] + fn resize_size_limit() { + let mut v: Vec = Vec::new(); + + v.resize(0, 0).unwrap(); + v.resize(4, 0).unwrap(); + v.resize(5, 0).expect_err("full"); + } + + #[test] + fn resize_length_cases() { + let mut v: Vec = Vec::new(); + + assert_eq!(v.len(), 0); + + // Grow by 1 + v.resize(1, 0).unwrap(); + assert_eq!(v.len(), 1); + + // Grow by 2 + v.resize(3, 0).unwrap(); + assert_eq!(v.len(), 3); + + // Resize to current size + v.resize(3, 0).unwrap(); + assert_eq!(v.len(), 3); + + // Shrink by 1 + v.resize(2, 0).unwrap(); + assert_eq!(v.len(), 2); + + // Shrink by 2 + v.resize(0, 0).unwrap(); + assert_eq!(v.len(), 0); + } + + #[test] + fn resize_contents() { + let mut v: Vec = Vec::new(); + + // New entries take supplied value when growing + v.resize(1, 17).unwrap(); + assert_eq!(v[0], 17); + + // Old values aren't changed when growing + v.resize(2, 18).unwrap(); + assert_eq!(v[0], 17); + assert_eq!(v[1], 18); + + // Old values aren't changed when length unchanged + v.resize(2, 0).unwrap(); + assert_eq!(v[0], 17); + assert_eq!(v[1], 18); + + // Old values aren't changed when shrinking + v.resize(1, 0).unwrap(); + assert_eq!(v[0], 17); + } + + #[test] + fn resize_default() { + let mut v: Vec = Vec::new(); + + // resize_default is implemented using resize, so just check the + // correct value is being written. + v.resize_default(1).unwrap(); + assert_eq!(v[0], 0); + } + + #[test] + fn write() { + let mut v: Vec = Vec::new(); + write!(v, "{:x}", 1234).unwrap(); + assert_eq!(&v[..], b"4d2"); + } + + #[test] + fn extend_from_slice() { + let mut v: Vec = Vec::new(); + assert_eq!(v.len(), 0); + v.extend_from_slice(&[1, 2]).unwrap(); + assert_eq!(v.len(), 2); + assert_eq!(v.as_slice(), &[1, 2]); + v.extend_from_slice(&[3]).unwrap(); + assert_eq!(v.len(), 3); + assert_eq!(v.as_slice(), &[1, 2, 3]); + assert!(v.extend_from_slice(&[4, 5]).is_err()); + assert_eq!(v.len(), 3); + assert_eq!(v.as_slice(), &[1, 2, 3]); + } + + #[test] + fn from_slice() { + // Successful construction + let v: Vec = Vec::from_slice(&[1, 2, 3]).unwrap(); + assert_eq!(v.len(), 3); + assert_eq!(v.as_slice(), &[1, 2, 3]); + + // Slice too large + assert!(Vec::::from_slice(&[1, 2, 3]).is_err()); + } + + #[test] + fn starts_with() { + let v: Vec<_, 8> = Vec::from_slice(b"ab").unwrap(); + assert!(v.starts_with(&[])); + assert!(v.starts_with(b"")); + assert!(v.starts_with(b"a")); + assert!(v.starts_with(b"ab")); + assert!(!v.starts_with(b"abc")); + assert!(!v.starts_with(b"ba")); + assert!(!v.starts_with(b"b")); + } + + #[test] + fn ends_with() { + let v: Vec<_, 8> = Vec::from_slice(b"ab").unwrap(); + assert!(v.ends_with(&[])); + assert!(v.ends_with(b"")); + assert!(v.ends_with(b"b")); + assert!(v.ends_with(b"ab")); + assert!(!v.ends_with(b"abc")); + assert!(!v.ends_with(b"ba")); + assert!(!v.ends_with(b"a")); + } + + #[test] + fn zero_capacity() { + let mut v: Vec = Vec::new(); + // Validate capacity + assert_eq!(v.capacity(), 0); + + // Make sure there is no capacity + assert!(v.push(1).is_err()); + + // Validate length + assert_eq!(v.len(), 0); + + // Validate pop + assert_eq!(v.pop(), None); + + // Validate slice + assert_eq!(v.as_slice(), &[]); + + // Validate empty + assert!(v.is_empty()); + + // Validate full + assert!(v.is_full()); + } +} diff --git a/rs-matter/src/utils/writebuf.rs b/rs-matter/src/utils/storage/writebuf.rs similarity index 99% rename from rs-matter/src/utils/writebuf.rs rename to rs-matter/src/utils/storage/writebuf.rs index f3363cc4..0dcb873e 100644 --- a/rs-matter/src/utils/writebuf.rs +++ b/rs-matter/src/utils/storage/writebuf.rs @@ -213,7 +213,7 @@ impl<'a> WriteBuf<'a> { #[cfg(test)] mod tests { - use crate::utils::writebuf::*; + use crate::utils::storage::WriteBuf; #[test] fn test_append_le_with_success() { diff --git a/rs-matter/src/utils/sync.rs b/rs-matter/src/utils/sync.rs new file mode 100644 index 00000000..d1174467 --- /dev/null +++ b/rs-matter/src/utils/sync.rs @@ -0,0 +1,26 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +pub use mutex::*; +pub use notification::*; +pub use signal::*; + +pub mod blocking; + +mod mutex; +mod notification; +mod signal; diff --git a/rs-matter/src/utils/sync/blocking.rs b/rs-matter/src/utils/sync/blocking.rs new file mode 100644 index 00000000..6d599dbd --- /dev/null +++ b/rs-matter/src/utils/sync/blocking.rs @@ -0,0 +1,253 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +pub use mutex::*; + +mod mutex { + //! A variation of the `embassy-sync` blocking mutex that allows in-place initialization + //! of the mutex with `Mutex::init(..) -> impl Init`. + //! Check `embassy_sync::blocking_mutex::Mutex` for the original implementation. + + #![allow(clippy::should_implement_trait)] + + use core::cell::UnsafeCell; + + use embassy_sync::blocking_mutex::raw::{self, RawMutex}; + + use crate::utils::init::{init, Init, UnsafeCellInit}; + + /// Blocking mutex (not async) + /// + /// Provides a blocking mutual exclusion primitive backed by an implementation of [`raw::RawMutex`]. + /// + /// Which implementation you select depends on the context in which you're using the mutex, and you can choose which kind + /// of interior mutability fits your use case. + /// + /// Use [`CriticalSectionMutex`] when data can be shared between threads and interrupts. + /// + /// Use [`NoopMutex`] when data is only shared between tasks running on the same executor. + /// + /// Use [`ThreadModeMutex`] when data is shared between tasks running on the same executor but you want a global singleton. + /// + /// In all cases, the blocking mutex is intended to be short lived and not held across await points. + /// Use the async [`Mutex`](crate::mutex::Mutex) if you need a lock that is held across await points. + pub struct Mutex { + // NOTE: `raw` must be FIRST, so when using ThreadModeMutex the "can't drop in non-thread-mode" gets + // to run BEFORE dropping `data`. + raw: R, + data: UnsafeCell, + } + + unsafe impl Send for Mutex {} + unsafe impl Sync for Mutex {} + + impl Mutex { + /// Creates a new mutex in an unlocked state ready for use. + #[inline] + pub const fn new(val: T) -> Mutex { + Mutex { + raw: R::INIT, + data: UnsafeCell::new(val), + } + } + + /// Creates a mutex in-place initializer in an unlocked state ready for use. + pub fn init>(val: I) -> impl Init { + init!(Self { + raw: R::INIT, + data <- UnsafeCell::init(val), + }) + } + + /// Creates a critical section and grants temporary access to the protected data. + pub fn lock(&self, f: impl FnOnce(&T) -> U) -> U { + self.raw.lock(|| { + let ptr = self.data.get() as *const T; + let inner = unsafe { &*ptr }; + f(inner) + }) + } + } + + impl Mutex { + /// Creates a new mutex based on a pre-existing raw mutex. + /// + /// This allows creating a mutex in a constant context on stable Rust. + #[inline] + pub const fn const_new(raw_mutex: R, val: T) -> Mutex { + Mutex { + raw: raw_mutex, + data: UnsafeCell::new(val), + } + } + + /// Consumes this mutex, returning the underlying data. + #[inline] + pub fn into_inner(self) -> T { + self.data.into_inner() + } + + /// Returns a mutable reference to the underlying data. + /// + /// Since this call borrows the `Mutex` mutably, no actual locking needs to + /// take place---the mutable borrow statically guarantees no locks exist. + #[inline] + pub fn get_mut(&mut self) -> &mut T { + unsafe { &mut *self.data.get() } + } + } + + /// A mutex that allows borrowing data across executors and interrupts. + /// + /// # Safety + /// + /// This mutex is safe to share between different executors and interrupts. + pub type CriticalSectionMutex = Mutex; + + /// A mutex that allows borrowing data in the context of a single executor. + /// + /// # Safety + /// + /// **This Mutex is only safe within a single executor.** + pub type NoopMutex = Mutex; + + impl Mutex { + /// Borrows the data for the duration of the critical section + pub fn borrow<'cs>(&'cs self, _cs: critical_section::CriticalSection<'cs>) -> &'cs T { + let ptr = self.data.get() as *const T; + unsafe { &*ptr } + } + } + + impl Mutex { + /// Borrows the data + pub fn borrow(&self) -> &T { + let ptr = self.data.get() as *const T; + unsafe { &*ptr } + } + } + + // // ThreadModeMutex does NOT use the generic mutex from above because it's special: + // // it's Send+Sync even if T: !Send. There's no way to do that without specialization (I think?). + // // + // // There's still a ThreadModeRawMutex for use with the generic Mutex (handy with Channel, for example), + // // but that will require T: Send even though it shouldn't be needed. + + // #[cfg(any(cortex_m, feature = "std"))] + // pub use thread_mode_mutex::*; + // #[cfg(any(cortex_m, feature = "std"))] + // mod thread_mode_mutex { + // use super::*; + + // /// A "mutex" that only allows borrowing from thread mode. + // /// + // /// # Safety + // /// + // /// **This Mutex is only safe on single-core systems.** + // /// + // /// On multi-core systems, a `ThreadModeMutex` **is not sufficient** to ensure exclusive access. + // pub struct ThreadModeMutex { + // inner: UnsafeCell, + // } + + // // NOTE: ThreadModeMutex only allows borrowing from one execution context ever: thread mode. + // // Therefore it cannot be used to send non-sendable stuff between execution contexts, so it can + // // be Send+Sync even if T is not Send (unlike CriticalSectionMutex) + // unsafe impl Sync for ThreadModeMutex {} + // unsafe impl Send for ThreadModeMutex {} + + // impl ThreadModeMutex { + // /// Creates a new mutex + // pub const fn new(value: T) -> Self { + // ThreadModeMutex { + // inner: UnsafeCell::new(value), + // } + // } + // } + + // impl ThreadModeMutex { + // /// Lock the `ThreadModeMutex`, granting access to the data. + // /// + // /// # Panics + // /// + // /// This will panic if not currently running in thread mode. + // pub fn lock(&self, f: impl FnOnce(&T) -> R) -> R { + // f(self.borrow()) + // } + + // /// Borrows the data + // /// + // /// # Panics + // /// + // /// This will panic if not currently running in thread mode. + // pub fn borrow(&self) -> &T { + // assert!( + // raw::in_thread_mode(), + // "ThreadModeMutex can only be borrowed from thread mode." + // ); + // unsafe { &*self.inner.get() } + // } + // } + + // impl Drop for ThreadModeMutex { + // fn drop(&mut self) { + // // Only allow dropping from thread mode. Dropping calls drop on the inner `T`, so + // // `drop` needs the same guarantees as `lock`. `ThreadModeMutex` is Send even if + // // T isn't, so without this check a user could create a ThreadModeMutex in thread mode, + // // send it to interrupt context and drop it there, which would "send" a T even if T is not Send. + // assert!( + // raw::in_thread_mode(), + // "ThreadModeMutex can only be dropped from thread mode." + // ); + + // // Drop of the inner `T` happens after this. + // } + // } + // } +} + +pub mod raw { + #[cfg(feature = "std")] + pub use std::*; + + #[cfg(feature = "std")] + mod std { + use embassy_sync::blocking_mutex::raw::RawMutex; + + /// An `embassy-sync` `RawMutex` implementation using `std::sync::Mutex`. + // TODO: Upstream into `embassy-sync` itself. + #[derive(Default)] + pub struct StdRawMutex(std::sync::Mutex<()>); + + impl StdRawMutex { + pub const fn new() -> Self { + Self(std::sync::Mutex::new(())) + } + } + + unsafe impl RawMutex for StdRawMutex { + #[allow(clippy::declare_interior_mutable_const)] + const INIT: Self = StdRawMutex(std::sync::Mutex::new(())); + + fn lock(&self, f: impl FnOnce() -> R) -> R { + let _guard = self.0.lock().unwrap(); + + f() + } + } + } +} diff --git a/rs-matter/src/utils/ifmutex.rs b/rs-matter/src/utils/sync/mutex.rs similarity index 94% rename from rs-matter/src/utils/ifmutex.rs rename to rs-matter/src/utils/sync/mutex.rs index e9dd8ef1..ab7590a0 100644 --- a/rs-matter/src/utils/ifmutex.rs +++ b/rs-matter/src/utils/sync/mutex.rs @@ -15,13 +15,17 @@ * limitations under the License. */ -//! A variation of the `embassy-sync` async mutex that only locks the mutex if a certain condition on the content of the data holds true. +//! A variation of the `embassy-sync` async mutex that only locks the mutex if a certain +//! condition on the content of the data holds true. //! Check `embassy_sync::Mutex` for the original unconditional implementation. + use core::cell::UnsafeCell; use core::ops::{Deref, DerefMut}; use embassy_sync::blocking_mutex::raw::RawMutex; +use crate::utils::init::{init, Init, UnsafeCellInit}; + use super::signal::Signal; /// Error returned by [`Mutex::try_lock`] @@ -55,6 +59,14 @@ where inner: UnsafeCell::new(value), } } + + /// Creates a mutex in-place initializer with the given value initializer. + pub fn init>(value: I) -> impl Init { + init!(Self { + state: Signal::::new(false), + inner <- UnsafeCell::init(value), + }) + } } impl IfMutex diff --git a/rs-matter/src/utils/notification.rs b/rs-matter/src/utils/sync/notification.rs similarity index 100% rename from rs-matter/src/utils/notification.rs rename to rs-matter/src/utils/sync/notification.rs diff --git a/rs-matter/src/utils/signal.rs b/rs-matter/src/utils/sync/signal.rs similarity index 74% rename from rs-matter/src/utils/signal.rs rename to rs-matter/src/utils/sync/signal.rs index a747f654..cc7fc1c8 100644 --- a/rs-matter/src/utils/signal.rs +++ b/rs-matter/src/utils/sync/signal.rs @@ -15,18 +15,38 @@ * limitations under the License. */ -use core::cell::RefCell; use core::future::poll_fn; use core::task::{Context, Poll}; -use embassy_sync::blocking_mutex::{raw::RawMutex, Mutex}; +use embassy_sync::blocking_mutex::raw::RawMutex; use embassy_sync::waitqueue::WakerRegistration; +use crate::utils::cell::RefCell; +use crate::utils::init::{init, Init}; + +use super::blocking::Mutex; + struct State { state: S, waker: WakerRegistration, } +impl State { + const fn new(state: S) -> Self { + Self { + state, + waker: WakerRegistration::new(), + } + } + + fn init>(state: I) -> impl Init { + init!(Self { + state <- state, + waker: WakerRegistration::new(), + }) + } +} + /// `Signal` is an async synchonization primitive that can be viewed as a generalization of the `embassy_sync::Signal` primitive /// that takes callback closures. /// @@ -38,18 +58,26 @@ struct State { /// The generic nature of `Signal` allows for a wide range of use cases, including the implementation of: /// - the `Notification` primitive /// - the `IfMutex` primitive -pub struct Signal(Mutex>>); +pub struct Signal { + inner: Mutex>>, +} impl Signal where M: RawMutex, { - /// Crate a `Signal` with the given initial state `S`. + /// Create a `Signal` with the given initial state `S`. pub const fn new(state: S) -> Self { - Self(Mutex::new(RefCell::new(State { - state, - waker: WakerRegistration::new(), - }))) + Self { + inner: Mutex::new(RefCell::new(State::new(state))), + } + } + + /// Create a `Signal` in-place initializer with the given initial state initializer `I`. + pub fn init>(state: I) -> impl Init { + init!(Self { + inner <- Mutex::init(RefCell::init(State::init(state))), + }) } // Modify the state `S` and wake up the waiters if necessary. @@ -57,7 +85,7 @@ where where F: FnOnce(&mut S) -> (bool, R), { - self.0.lock(|s| { + self.inner.lock(|s| { let mut s = s.borrow_mut(); let (wake, result) = f(&mut s.state); @@ -83,7 +111,7 @@ where where F: FnOnce(&mut S) -> Option, { - self.0.lock(|s| { + self.inner.lock(|s| { let mut s = s.borrow_mut(); if let Some(result) = f(&mut s.state) { diff --git a/rs-matter/tests/common/attributes.rs b/rs-matter/tests/common/attributes.rs index eabef20c..55a5841b 100644 --- a/rs-matter/tests/common/attributes.rs +++ b/rs-matter/tests/common/attributes.rs @@ -18,7 +18,7 @@ use rs_matter::{ interaction_model::{messages::ib::AttrResp, messages::msg::ReportDataMsg}, tlv::{TLVElement, TLVList, TLVWriter, TagType, ToTLV}, - utils::writebuf::WriteBuf, + utils::storage::WriteBuf, }; /// Assert that the data received in the outbuf matches our expectations diff --git a/rs-matter/tests/common/im_engine.rs b/rs-matter/tests/common/im_engine.rs index 3a0b3b5d..50fd67ca 100644 --- a/rs-matter/tests/common/im_engine.rs +++ b/rs-matter/tests/common/im_engine.rs @@ -65,7 +65,7 @@ use rs_matter::{ }, session::{NocCatIds, ReservedSession, SessionMode}, }, - utils::{buf::PooledBuffers, select::Coalesce}, + utils::{select::Coalesce, storage::pooled::PooledBuffers}, Matter, MATTER_PORT, }; @@ -295,7 +295,7 @@ impl<'a> ImEngine<'a> { ADDR, SessionMode::Case { fab_idx: NonZeroU8::new(1).unwrap(), - cat_ids: cat_ids.clone(), + cat_ids: *cat_ids, }, None, None, diff --git a/rs-matter/tests/tlv_encoding.rs b/rs-matter/tests/tlv_encoding.rs index 5ce9664c..99a0b003 100644 --- a/rs-matter/tests/tlv_encoding.rs +++ b/rs-matter/tests/tlv_encoding.rs @@ -20,7 +20,7 @@ mod tlv_encoding_tests { use rs_matter::bitflags_tlv; use rs_matter::error::Error; use rs_matter::tlv::{get_root_node, FromTLV, TLVElement, TLVWriter, TagType, ToTLV}; - use rs_matter::utils::writebuf::WriteBuf; + use rs_matter::utils::storage::WriteBuf; #[derive(PartialEq, Debug, ToTLV, FromTLV)] struct SimpleStruct {