diff --git a/Cargo.lock b/Cargo.lock index eb53944590..86c69adc3b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -321,6 +321,7 @@ dependencies = [ "azure_identity", "clap", "futures", + "moka", "reqwest", "serde", "serde_json", @@ -884,6 +885,15 @@ dependencies = [ "itertools", ] +[[package]] +name = "crossbeam-channel" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-deque" version = "0.8.6" @@ -1792,6 +1802,27 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "moka" +version = "0.12.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8261cd88c312e0004c1d51baad2980c66528dfdb2bee62003e643a4d8f86b077" +dependencies = [ + "async-lock", + "crossbeam-channel", + "crossbeam-epoch", + "crossbeam-utils", + "equivalent", + "event-listener", + "futures-util", + "parking_lot", + "portable-atomic", + "rustc_version", + "smallvec", + "tagptr", + "uuid", +] + [[package]] name = "native-tls" version = "0.2.14" @@ -2092,6 +2123,12 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "portable-atomic" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" + [[package]] name = "potential_utf" version = "0.1.2" @@ -2789,6 +2826,12 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "tagptr" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" + [[package]] name = "tap" version = "1.0.1" diff --git a/Cargo.toml b/Cargo.toml index a729af313f..1041ae9882 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -109,6 +109,7 @@ getrandom = { version = "0.3" } gloo-timers = { version = "0.3" } hmac = { version = "0.12" } litemap = "0.7.4" +moka = { version = "0.12.11", features = ["future"] } openssl = { version = "0.10.72" } opentelemetry = { version = "0.30", features = ["trace"] } opentelemetry_sdk = "0.30" diff --git a/eng/dict/crates.txt b/eng/dict/crates.txt index 7af9fde24f..7e56bd271d 100644 --- a/eng/dict/crates.txt +++ b/eng/dict/crates.txt @@ -1,27 +1,32 @@ async-lock async-stream async-trait -azure_core +azure_canary +azure_canary_core azure_core azure_core_amqp -azure_core_amqp -azure_core_test +azure_core_macros +azure_core_opentelemetry azure_core_test azure_core_test_macros -azure_core_test_macros azure_data_cosmos azure_identity -azure_identity azure_messaging_eventhubs +azure_messaging_eventhubs_checkpointstore_blob +azure_messaging_servicebus +azure_security_keyvault_certificates azure_security_keyvault_keys azure_security_keyvault_secrets +azure_security_keyvault_test azure_storage_blob -azure_canary -azure_canary_core +azure_storage_blob_test +azure_storage_common +azure_storage_queue base64 bytes cargo_metadata clap +criterion dotenvy dyn-clone fe2o3-amqp @@ -32,11 +37,17 @@ fe2o3-amqp-types flate2 futures getrandom -gloo +gloo-timers hmac +http litemap -log +moka openssl +opentelemetry +opentelemetry-appender-tracing +opentelemetry-http +opentelemetry-stdout +opentelemetry_sdk pin-project proc-macro2 quick-xml @@ -44,6 +55,7 @@ quote rand rand_chacha reqwest +rust_decimal rustc_version serde serde_amqp @@ -52,7 +64,6 @@ serde_json serde_test serial_test sha2 -storage syn tar thiserror @@ -61,13 +72,12 @@ tokio tracing tracing-subscriber typespec -typespec -typespec_client_core typespec_client_core typespec_macros -typespec_macros ureq url uuid +wasm-bindgen-futures +wasm-bindgen-test zerofrom zip diff --git a/eng/dict/rust-custom.txt b/eng/dict/rust-custom.txt index b3b37a5653..dd2c735cc9 100644 --- a/eng/dict/rust-custom.txt +++ b/eng/dict/rust-custom.txt @@ -5,6 +5,7 @@ consts deque impls newtype +newtypes oneshot repr rustc diff --git a/eng/scripts/update-cratenames.rs b/eng/scripts/update-cratenames.rs index e3259dd6e9..5875b7f02b 100755 --- a/eng/scripts/update-cratenames.rs +++ b/eng/scripts/update-cratenames.rs @@ -9,7 +9,7 @@ toml = "0.8.10" --- use cargo_util_schemas::manifest::TomlManifest; -use std::{ffi::OsStr, fs, io::Write as _, path::PathBuf}; +use std::{collections::HashSet, fs, io::Write as _, path::PathBuf}; fn main() { let workspace_root = get_workspace_root(); @@ -29,25 +29,49 @@ fn main() { .dependencies .as_ref() .expect("expected workspace dependencies"); - let mut crate_names: Vec = dependencies.iter().map(|(name, _)| name.to_string()).collect(); + let mut crate_names: HashSet = dependencies + .iter() + .map(|(name, _)| name.to_string()) + .collect(); // Extract workspace members. - for relative_path in workspace_manifest + let members = workspace_manifest .workspace .as_ref() .expect("expected workspace") .members .as_ref() - .expect("expected workspace members") - .into_iter() { - let crate_name = PathBuf::from(relative_path) - .file_stem() - .and_then(OsStr::to_str) - .expect("expected crate name") - .to_string(); - crate_names.push(crate_name); + .expect("expected workspace members"); + + for relative_path in members.into_iter() { + let member_dir_name = PathBuf::from(relative_path) + .file_stem() + .map(|s| s.to_string_lossy().into_owned()) + .expect("member directory name"); + let member_manifest_path = workspace_root.join(relative_path).join("Cargo.toml"); + let member_manifest = fs::read_to_string(&member_manifest_path).unwrap(); + let member_manifest: TomlManifest = toml::from_str(&member_manifest).unwrap(); + + let sections = [ + member_manifest.dependencies, + member_manifest.dev_dependencies, + member_manifest.build_dependencies, + ]; + for section in sections.iter().flatten() { + for (name, _) in section.iter() { + crate_names.insert(name.to_string()); + } } + // Add the crate name for this member. + if let Some(name) = member_manifest.package.as_ref().map(|p| p.name.as_ref()) { + crate_names.insert(name.to_string()); + } else { + crate_names.insert(member_dir_name.to_string()); + } + } + + let mut crate_names: Vec = crate_names.into_iter().collect(); crate_names.sort(); let crate_names_path = workspace_root .join("eng/dict/crates.txt") diff --git a/sdk/cosmos/.dict.txt b/sdk/cosmos/.dict.txt index f6b563e481..990beab33e 100644 --- a/sdk/cosmos/.dict.txt +++ b/sdk/cosmos/.dict.txt @@ -5,6 +5,7 @@ euclidian pkranges sprocs udfs +substatus # Cosmos' docs all use "Autoscale" as a single word, rather than a compound "AutoScale" or "Auto Scale" autoscale diff --git a/sdk/cosmos/assets.json b/sdk/cosmos/assets.json index 0e743ffd6e..c6fbb68160 100644 --- a/sdk/cosmos/assets.json +++ b/sdk/cosmos/assets.json @@ -1,6 +1,6 @@ { "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "rust", - "Tag": "rust/azure_data_cosmos_a39b424a5b", + "Tag": "rust/azure_data_cosmos_69ad1e4995", "TagPrefix": "rust/azure_data_cosmos" } \ No newline at end of file diff --git a/sdk/cosmos/azure_data_cosmos/Cargo.toml b/sdk/cosmos/azure_data_cosmos/Cargo.toml index 22dce2f890..188dd52127 100644 --- a/sdk/cosmos/azure_data_cosmos/Cargo.toml +++ b/sdk/cosmos/azure_data_cosmos/Cargo.toml @@ -22,6 +22,7 @@ serde.workspace = true tracing.workspace = true typespec_client_core = { workspace = true, features = ["derive"] } url.workspace = true +moka.workspace = true [dev-dependencies] azure_identity.workspace = true diff --git a/sdk/cosmos/azure_data_cosmos/src/cache.rs b/sdk/cosmos/azure_data_cosmos/src/cache.rs new file mode 100644 index 0000000000..84082002dc --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos/src/cache.rs @@ -0,0 +1,131 @@ +// cSpell:ignore smol + +use std::sync::Arc; + +use moka::future::Cache; + +use crate::{models::ContainerProperties, resource_context::ResourceLink, ResourceId}; + +#[derive(Debug)] +pub enum CacheError { + FetchError(Arc), +} + +impl From> for CacheError { + fn from(e: Arc) -> Self { + CacheError::FetchError(e) + } +} + +impl From for azure_core::Error { + fn from(e: CacheError) -> Self { + match e { + CacheError::FetchError(e) => { + let message = format!("error updating Container Metadata Cache: {}", e); + azure_core::Error::with_error(azure_core::error::ErrorKind::Other, e, message) + } + } + } +} + +impl std::fmt::Display for CacheError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + CacheError::FetchError(e) => write!(f, "error fetching latest value: {}", e), + } + } +} + +impl std::error::Error for CacheError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + CacheError::FetchError(e) => Some(&**e), + } + } +} + +/// A subset of container properties that are stable and suitable for caching. +pub(crate) struct ContainerMetadata { + pub resource_id: ResourceId, + pub container_link: ResourceLink, +} + +impl ContainerMetadata { + // We can't use From because we also want the container link. + pub fn from_properties( + properties: &ContainerProperties, + container_link: ResourceLink, + ) -> azure_core::Result { + let resource_id = properties + .system_properties + .resource_id + .clone() + .ok_or_else(|| { + azure_core::Error::new( + azure_core::error::ErrorKind::Other, + "container properties is missing expected value 'resource_id'", + ) + })?; + Ok(Self { + resource_id, + container_link, + }) + } +} + +/// A cache for container metadata, including properties and routing information. +/// +/// The cache can be cloned cheaply, and all clones share the same underlying cache data. +#[derive(Clone)] +pub struct ContainerMetadataCache { + /// Caches stable container metadata, mapping from container link to metadata. + container_properties_cache: Cache>, +} + +// TODO: Review this value. +// Cosmos has a backend limit of 500 databases and containers per account by default. +// This value affects when Moka will start evicting entries from the cache. +// It could probably be much lower without much impact, but we need to do the research to be sure. +const MAX_CACHE_CAPACITY: u64 = 500; + +impl ContainerMetadataCache { + /// Creates a new `ContainerMetadataCache` with default settings. + /// + /// Since the cache is designed to be shared, it is returned inside an `Arc`. + pub fn new() -> Self { + let container_properties_cache = Cache::new(MAX_CACHE_CAPACITY); + Self { + container_properties_cache, + } + } + + /// Unconditionally updates the cache with the provided container metadata. + pub async fn set_container_metadata(&self, metadata: ContainerMetadata) { + let metadata = Arc::new(metadata); + + self.container_properties_cache + .insert(metadata.container_link.clone(), metadata) + .await; + } + + /// Gets the container metadata from the cache, or initializes it using the provided async function if not present. + pub async fn get_container_metadata( + &self, + key: &ResourceLink, + init: impl std::future::Future>, + ) -> Result, CacheError> { + // TODO: Background refresh. We can do background refresh by storing an expiry time in the cache entry. + // Then, if the entry is stale, we can return the stale entry and spawn a background task to refresh it. + // There's a little trickiness here in that we can't directly spawn a task because that depends on a specific Async Runtime (tokio, smol, etc). + // The core SDK has an AsyncRuntime abstraction that we can use to spawn the task. + Ok(self + .container_properties_cache + .try_get_with_by_ref(key, async { init.await.map(Arc::new) }) + .await?) + } + + /// Removes the container metadata from the cache, forcing a refresh on the next access. + pub async fn remove_container_metadata(&self, key: &ResourceLink) { + self.container_properties_cache.invalidate(key).await; + } +} diff --git a/sdk/cosmos/azure_data_cosmos/src/clients/container_client.rs b/sdk/cosmos/azure_data_cosmos/src/clients/container_client.rs index 7d28ee80f2..f3aa2a6618 100644 --- a/sdk/cosmos/azure_data_cosmos/src/clients/container_client.rs +++ b/sdk/cosmos/azure_data_cosmos/src/clients/container_client.rs @@ -1,12 +1,16 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +use std::sync::Arc; + use crate::{ + cache::ContainerMetadata, + connection::CosmosConnection, constants, models::{ContainerProperties, PatchDocument, ThroughputProperties}, options::{QueryOptions, ReadContainerOptions}, - pipeline::CosmosPipeline, resource_context::{ResourceLink, ResourceType}, + status::{CosmosStatus, ErrorExt}, DeleteContainerOptions, FeedPager, ItemOptions, PartitionKey, Query, ReplaceContainerOptions, ThroughputOptions, }; @@ -14,7 +18,7 @@ use crate::{ use azure_core::http::{ request::{options::ContentType, Request}, response::Response, - Method, + Method, RawResponse, }; use serde::{de::DeserializeOwned, Serialize}; @@ -25,12 +29,12 @@ use serde::{de::DeserializeOwned, Serialize}; pub struct ContainerClient { link: ResourceLink, items_link: ResourceLink, - pipeline: CosmosPipeline, + connection: CosmosConnection, } impl ContainerClient { pub(crate) fn new( - pipeline: CosmosPipeline, + connection: CosmosConnection, database_link: &ResourceLink, container_id: &str, ) -> Self { @@ -42,7 +46,7 @@ impl ContainerClient { Self { link, items_link, - pipeline, + connection, } } @@ -66,12 +70,18 @@ impl ContainerClient { &self, options: Option>, ) -> azure_core::Result> { - let options = options.unwrap_or_default(); - let url = self.pipeline.url(&self.link); - let mut req = Request::new(url, Method::Get); - self.pipeline - .send(options.method_options.context, &mut req, self.link.clone()) - .await + let response: RawResponse = self.read_properties(options).await?.into(); + + // Read the properties and cache the stable metadata (things that don't change for the life of a container) + // TODO: Replace with `response.body().json()` when that becomes borrowing. + let properties = serde_json::from_slice::(response.body())?; + let metadata = ContainerMetadata::from_properties(&properties, self.link.clone())?; + self.connection + .cache() + .set_container_metadata(metadata) + .await; + + Ok(response.into()) } /// Updates the indexing policy of the container. @@ -110,11 +120,11 @@ impl ContainerClient { options: Option>, ) -> azure_core::Result> { let options = options.unwrap_or_default(); - let url = self.pipeline.url(&self.link); + let url = self.connection.url(&self.link); let mut req = Request::new(url, Method::Put); req.insert_headers(&ContentType::APPLICATION_JSON)?; req.set_json(&properties)?; - self.pipeline + self.connection .send(options.method_options.context, &mut req, self.link.clone()) .await } @@ -130,17 +140,11 @@ impl ContainerClient { options: Option>, ) -> azure_core::Result>> { let options = options.unwrap_or_default(); - - // We need to get the RID for the database. - let db = self.read(None).await?.into_body()?; - let resource_id = db - .system_properties - .resource_id - .expect("service should always return a '_rid' for a container"); - - self.pipeline - .read_throughput_offer(options.method_options.context, &resource_id) - .await + self.retry_if_cache_stale(|metadata| async move { + self.connection + .read_throughput_offer(options.method_options.context, &metadata.resource_id) + .await + }) } /// Replaces the container throughput properties. @@ -156,14 +160,9 @@ impl ContainerClient { let options = options.unwrap_or_default(); // We need to get the RID for the database. - let db = self.read(None).await?.into_body()?; - let resource_id = db - .system_properties - .resource_id - .expect("service should always return a '_rid' for a container"); - - self.pipeline - .replace_throughput_offer(options.method_options.context, &resource_id, throughput) + let resource_id = &self.metadata().await?.resource_id; + self.connection + .replace_throughput_offer(options.method_options.context, resource_id, throughput) .await } @@ -178,9 +177,9 @@ impl ContainerClient { options: Option>, ) -> azure_core::Result> { let options = options.unwrap_or_default(); - let url = self.pipeline.url(&self.link); + let url = self.connection.url(&self.link); let mut req = Request::new(url, Method::Delete); - self.pipeline + self.connection .send(options.method_options.context, &mut req, self.link.clone()) .await } @@ -257,13 +256,13 @@ impl ContainerClient { options: Option>, ) -> azure_core::Result> { let options = options.unwrap_or_default(); - let url = self.pipeline.url(&self.items_link); + let url = self.connection.url(&self.items_link); let mut req = Request::new(url, Method::Post); req.insert_headers(&options)?; req.insert_headers(&partition_key.into())?; req.insert_headers(&ContentType::APPLICATION_JSON)?; req.set_json(&item)?; - self.pipeline + self.connection .send( options.method_options.context, &mut req, @@ -346,13 +345,13 @@ impl ContainerClient { ) -> azure_core::Result> { let options = options.unwrap_or_default(); let link = self.items_link.item(item_id); - let url = self.pipeline.url(&link); + let url = self.connection.url(&link); let mut req = Request::new(url, Method::Put); req.insert_headers(&options)?; req.insert_headers(&partition_key.into())?; req.insert_headers(&ContentType::APPLICATION_JSON)?; req.set_json(&item)?; - self.pipeline + self.connection .send(options.method_options.context, &mut req, link) .await } @@ -432,14 +431,14 @@ impl ContainerClient { options: Option>, ) -> azure_core::Result> { let options = options.unwrap_or_default(); - let url = self.pipeline.url(&self.items_link); + let url = self.connection.url(&self.items_link); let mut req = Request::new(url, Method::Post); req.insert_headers(&options)?; req.insert_header(constants::IS_UPSERT, "true"); req.insert_headers(&partition_key.into())?; req.insert_headers(&ContentType::APPLICATION_JSON)?; req.set_json(&item)?; - self.pipeline + self.connection .send( options.method_options.context, &mut req, @@ -490,11 +489,11 @@ impl ContainerClient { options.enable_content_response_on_write = true; let link = self.items_link.item(item_id); - let url = self.pipeline.url(&link); + let url = self.connection.url(&link); let mut req = Request::new(url, Method::Get); req.insert_headers(&options)?; req.insert_headers(&partition_key.into())?; - self.pipeline + self.connection .send(options.method_options.context, &mut req, link) .await } @@ -527,11 +526,11 @@ impl ContainerClient { ) -> azure_core::Result> { let options = options.unwrap_or_default(); let link = self.items_link.item(item_id); - let url = self.pipeline.url(&link); + let url = self.connection.url(&link); let mut req = Request::new(url, Method::Delete); req.insert_headers(&options)?; req.insert_headers(&partition_key.into())?; - self.pipeline + self.connection .send(options.method_options.context, &mut req, link) .await } @@ -600,14 +599,14 @@ impl ContainerClient { ) -> azure_core::Result> { let options = options.unwrap_or_default(); let link = self.items_link.item(item_id); - let url = self.pipeline.url(&link); + let url = self.connection.url(&link); let mut req = Request::new(url, Method::Patch); req.insert_headers(&options)?; req.insert_headers(&partition_key.into())?; req.insert_headers(&ContentType::APPLICATION_JSON)?; req.set_json(&patch)?; - self.pipeline + self.connection .send(options.method_options.context, &mut req, link) .await } @@ -686,7 +685,7 @@ impl ContainerClient { if partition_key.is_empty() { if let Some(query_engine) = options.query_engine.take() { return crate::query::executor::QueryExecutor::new( - self.pipeline.clone(), + self.connection.clone(), self.link.clone(), query, options, @@ -696,8 +695,8 @@ impl ContainerClient { } } - let url = self.pipeline.url(&self.items_link); - self.pipeline.send_query_request( + let url = self.connection.url(&self.items_link); + self.connection.send_query_request( options.method_options.context, query, url, @@ -705,4 +704,53 @@ impl ContainerClient { |r| r.insert_headers(&partition_key), ) } + + async fn read_properties( + &self, + options: Option>, + ) -> azure_core::Result> { + let options = options.unwrap_or_default(); + let url = self.connection.url(&self.link); + let mut req = Request::new(url, Method::Get); + self.connection + .send(options.method_options.context, &mut req, self.link.clone()) + .await + } + + /// Executes the provided closure with cached container metadata, retrying once after refreshing the cache if the cache is stale. + /// + /// We only provide this mechanism for reading the metadata, to ensure we refresh the cache when necessary. + // TODO: If we need a way to write with cached metadata (since those operations may not be idempotent), we can add that later. + async fn retry_if_cache_stale(&self, f: F) -> azure_core::Result + where + F: Fn(Arc) -> Fut, + Fut: std::future::Future>, + { + async fn get_metadata( + client: &ContainerClient, + ) -> azure_core::Result> { + Ok(client + .connection + .cache() + .get_container_metadata(&client.link, async { + let properties = client.read_properties(None).await?.into_body()?; + ContainerMetadata::from_properties(&properties, client.link.clone()) + }) + .await?) + } + + let metadata = get_metadata(self).await?; + match f(metadata).await { + Err(err) if err.cosmos_status()? == Some(CosmosStatus::NAME_CACHE_IS_STALE) => { + // Invalidate the cache and try again + self.connection + .cache() + .remove_container_metadata(&self.link) + .await; + let metadata = get_metadata(self).await?; + f(metadata).await + } + x => x, + } + } } diff --git a/sdk/cosmos/azure_data_cosmos/src/clients/cosmos_client.rs b/sdk/cosmos/azure_data_cosmos/src/clients/cosmos_client.rs index 209fadc284..418b0e82b2 100644 --- a/sdk/cosmos/azure_data_cosmos/src/clients/cosmos_client.rs +++ b/sdk/cosmos/azure_data_cosmos/src/clients/cosmos_client.rs @@ -3,8 +3,8 @@ use crate::{ clients::DatabaseClient, + connection::{AuthorizationPolicy, CosmosConnection}, models::DatabaseProperties, - pipeline::{AuthorizationPolicy, CosmosPipeline}, resource_context::{ResourceLink, ResourceType}, CosmosClientOptions, CreateDatabaseOptions, FeedPager, Query, QueryDatabasesOptions, }; @@ -23,10 +23,13 @@ use std::sync::Arc; use azure_core::credentials::Secret; /// Client for Azure Cosmos DB. -#[derive(Debug, Clone)] +/// +/// A [`CosmosClient`] can be safely shared between threads and is cheap to clone, as it holds most of the connection state in an [`Arc`]. +/// However, it's generally preferred to have a single `CosmosClient` per Cosmos account in your application, and share that between threads as needed. +#[derive(Clone)] pub struct CosmosClient { databases_link: ResourceLink, - pipeline: CosmosPipeline, + connection: CosmosConnection, } impl CosmosClient { @@ -55,7 +58,7 @@ impl CosmosClient { let options = options.unwrap_or_default(); Ok(Self { databases_link: ResourceLink::root(ResourceType::Databases), - pipeline: CosmosPipeline::new( + connection: CosmosConnection::new( endpoint.parse()?, AuthorizationPolicy::from_token_credential(credential), options.client_options, @@ -88,7 +91,7 @@ impl CosmosClient { let options = options.unwrap_or_default(); Ok(Self { databases_link: ResourceLink::root(ResourceType::Databases), - pipeline: CosmosPipeline::new( + connection: CosmosConnection::new( endpoint.parse()?, AuthorizationPolicy::from_shared_key(key), options.client_options, @@ -131,12 +134,12 @@ impl CosmosClient { /// # Arguments /// * `id` - The ID of the database. pub fn database_client(&self, id: &str) -> DatabaseClient { - DatabaseClient::new(self.pipeline.clone(), id) + DatabaseClient::new(self.connection.clone(), id) } /// Gets the endpoint of the database account this client is connected to. pub fn endpoint(&self) -> &Url { - &self.pipeline.endpoint + self.connection.endpoint() } /// Executes a query against databases in the account. @@ -168,9 +171,9 @@ impl CosmosClient { options: Option>, ) -> azure_core::Result> { let options = options.unwrap_or_default(); - let url = self.pipeline.url(&self.databases_link); + let url = self.connection.url(&self.databases_link); - self.pipeline.send_query_request( + self.connection.send_query_request( options.method_options.context, query.into(), url, @@ -198,13 +201,13 @@ impl CosmosClient { id: &'a str, } - let url = self.pipeline.url(&self.databases_link); + let url = self.connection.url(&self.databases_link); let mut req = Request::new(url, Method::Post); req.insert_headers(&options.throughput)?; req.insert_headers(&ContentType::APPLICATION_JSON)?; req.set_json(&RequestBody { id })?; - self.pipeline + self.connection .send( options.method_options.context, &mut req, diff --git a/sdk/cosmos/azure_data_cosmos/src/clients/database_client.rs b/sdk/cosmos/azure_data_cosmos/src/clients/database_client.rs index c9dc6889f0..8a7fd04bad 100644 --- a/sdk/cosmos/azure_data_cosmos/src/clients/database_client.rs +++ b/sdk/cosmos/azure_data_cosmos/src/clients/database_client.rs @@ -3,9 +3,9 @@ use crate::{ clients::ContainerClient, + connection::CosmosConnection, models::{ContainerProperties, DatabaseProperties, ThroughputProperties}, options::ReadDatabaseOptions, - pipeline::CosmosPipeline, resource_context::{ResourceLink, ResourceType}, CreateContainerOptions, DeleteDatabaseOptions, FeedPager, Query, QueryContainersOptions, ThroughputOptions, @@ -20,15 +20,16 @@ use azure_core::http::{ /// A client for working with a specific database in a Cosmos DB account. /// /// You can get a `DatabaseClient` by calling [`CosmosClient::database_client()`](crate::CosmosClient::database_client()). +#[derive(Clone)] pub struct DatabaseClient { link: ResourceLink, containers_link: ResourceLink, database_id: String, - pipeline: CosmosPipeline, + connection: CosmosConnection, } impl DatabaseClient { - pub(crate) fn new(pipeline: CosmosPipeline, database_id: &str) -> Self { + pub(crate) fn new(connection: CosmosConnection, database_id: &str) -> Self { let database_id = database_id.to_string(); let link = ResourceLink::root(ResourceType::Databases).item(&database_id); let containers_link = link.feed(ResourceType::Containers); @@ -37,7 +38,7 @@ impl DatabaseClient { link, containers_link, database_id, - pipeline, + connection, } } @@ -46,7 +47,7 @@ impl DatabaseClient { /// # Arguments /// * `name` - The name of the container. pub fn container_client(&self, name: &str) -> ContainerClient { - ContainerClient::new(self.pipeline.clone(), &self.link, name) + ContainerClient::new(self.connection.clone(), &self.link, name) } /// Returns the identifier of the Cosmos database. @@ -76,9 +77,9 @@ impl DatabaseClient { options: Option>, ) -> azure_core::Result> { let options = options.unwrap_or_default(); - let url = self.pipeline.url(&self.link); + let url = self.connection.url(&self.link); let mut req = Request::new(url, Method::Get); - self.pipeline + self.connection .send(options.method_options.context, &mut req, self.link.clone()) .await } @@ -112,9 +113,9 @@ impl DatabaseClient { options: Option>, ) -> azure_core::Result> { let options = options.unwrap_or_default(); - let url = self.pipeline.url(&self.containers_link); + let url = self.connection.url(&self.containers_link); - self.pipeline.send_query_request( + self.connection.send_query_request( options.method_options.context, query.into(), url, @@ -136,13 +137,13 @@ impl DatabaseClient { options: Option>, ) -> azure_core::Result> { let options = options.unwrap_or_default(); - let url = self.pipeline.url(&self.containers_link); + let url = self.connection.url(&self.containers_link); let mut req = Request::new(url, Method::Post); req.insert_headers(&options.throughput)?; req.insert_headers(&ContentType::APPLICATION_JSON)?; req.set_json(&properties)?; - self.pipeline + self.connection .send( options.method_options.context, &mut req, @@ -162,9 +163,9 @@ impl DatabaseClient { options: Option>, ) -> azure_core::Result> { let options = options.unwrap_or_default(); - let url = self.pipeline.url(&self.link); + let url = self.connection.url(&self.link); let mut req = Request::new(url, Method::Delete); - self.pipeline + self.connection .send(options.method_options.context, &mut req, self.link.clone()) .await } @@ -188,7 +189,7 @@ impl DatabaseClient { .resource_id .expect("service should always return a '_rid' for a database"); - self.pipeline + self.connection .read_throughput_offer(options.method_options.context, &resource_id) .await } @@ -212,7 +213,7 @@ impl DatabaseClient { .resource_id .expect("service should always return a '_rid' for a database"); - self.pipeline + self.connection .replace_throughput_offer(options.method_options.context, &resource_id, throughput) .await } diff --git a/sdk/cosmos/azure_data_cosmos/src/pipeline/authorization_policy.rs b/sdk/cosmos/azure_data_cosmos/src/connection/authorization_policy.rs similarity index 98% rename from sdk/cosmos/azure_data_cosmos/src/pipeline/authorization_policy.rs rename to sdk/cosmos/azure_data_cosmos/src/connection/authorization_policy.rs index 19cf51c2a0..7444323e52 100644 --- a/sdk/cosmos/azure_data_cosmos/src/pipeline/authorization_policy.rs +++ b/sdk/cosmos/azure_data_cosmos/src/connection/authorization_policy.rs @@ -21,7 +21,7 @@ use azure_core::{ use std::sync::Arc; use tracing::trace; -use crate::{pipeline::signature_target::SignatureTarget, resource_context::ResourceLink}; +use crate::{connection::signature_target::SignatureTarget, resource_context::ResourceLink}; use crate::utils::url_encode; @@ -153,7 +153,7 @@ mod tests { use url::Url; use crate::{ - pipeline::{ + connection::{ authorization_policy::{generate_authorization, scope_from_url, Credential}, signature_target::SignatureTarget, }, diff --git a/sdk/cosmos/azure_data_cosmos/src/pipeline/mod.rs b/sdk/cosmos/azure_data_cosmos/src/connection/mod.rs similarity index 89% rename from sdk/cosmos/azure_data_cosmos/src/pipeline/mod.rs rename to sdk/cosmos/azure_data_cosmos/src/connection/mod.rs index 0f84843a30..1debfe2faa 100644 --- a/sdk/cosmos/azure_data_cosmos/src/pipeline/mod.rs +++ b/sdk/cosmos/azure_data_cosmos/src/connection/mod.rs @@ -18,27 +18,33 @@ use serde::de::DeserializeOwned; use url::Url; use crate::{ + cache::ContainerMetadataCache, constants, models::ThroughputProperties, resource_context::{ResourceLink, ResourceType}, - FeedPage, FeedPager, Query, + FeedPage, FeedPager, Query, ResourceId, }; -/// Newtype that wraps an Azure Core pipeline to provide a Cosmos-specific pipeline which configures our authorization policy and enforces that a [`ResourceType`] is set on the context. -#[derive(Debug, Clone)] -pub struct CosmosPipeline { - pub endpoint: Url, +/// Represents a connection to a specific Cosmos account. +/// +/// The [`CosmosConnection`] holds all the shared state for a connection to a Cosmos DB account. +/// A connection is cheap to clone, and all clones share the same underlying HTTP pipeline and metadata cache. +#[derive(Clone)] +pub struct CosmosConnection { + endpoint: Url, + cache: ContainerMetadataCache, pipeline: azure_core::http::Pipeline, } -impl CosmosPipeline { +impl CosmosConnection { pub fn new( endpoint: Url, auth_policy: AuthorizationPolicy, client_options: ClientOptions, ) -> Self { - CosmosPipeline { + CosmosConnection { endpoint, + cache: ContainerMetadataCache::new(), pipeline: azure_core::http::Pipeline::new( option_env!("CARGO_PKG_NAME"), option_env!("CARGO_PKG_VERSION"), @@ -50,6 +56,14 @@ impl CosmosPipeline { } } + pub fn endpoint(&self) -> &Url { + &self.endpoint + } + + pub fn cache(&self) -> &ContainerMetadataCache { + &self.cache + } + /// Creates a [`Url`] out of the provided [`ResourceLink`] /// /// This is a little backwards, ideally we'd accept [`ResourceLink`] in the [`CosmosPipeline::send`] method, @@ -124,7 +138,7 @@ impl CosmosPipeline { pub async fn read_throughput_offer( &self, context: Context<'_>, - resource_id: &str, + resource_id: &ResourceId, ) -> azure_core::Result>> { // We only have to into_owned here in order to call send_query_request below, // since it returns `Pager` which must own it's data. @@ -164,7 +178,7 @@ impl CosmosPipeline { pub async fn replace_throughput_offer( &self, context: Context<'_>, - resource_id: &str, + resource_id: &ResourceId, throughput: ThroughputProperties, ) -> azure_core::Result> { let response = self diff --git a/sdk/cosmos/azure_data_cosmos/src/pipeline/signature_target.rs b/sdk/cosmos/azure_data_cosmos/src/connection/signature_target.rs similarity index 98% rename from sdk/cosmos/azure_data_cosmos/src/pipeline/signature_target.rs rename to sdk/cosmos/azure_data_cosmos/src/connection/signature_target.rs index 6db3089903..a3ad1558da 100644 --- a/sdk/cosmos/azure_data_cosmos/src/pipeline/signature_target.rs +++ b/sdk/cosmos/azure_data_cosmos/src/connection/signature_target.rs @@ -70,7 +70,7 @@ mod tests { use azure_core::{http::Method, time}; use crate::{ - pipeline::signature_target::SignatureTarget, + connection::signature_target::SignatureTarget, resource_context::{ResourceLink, ResourceType}, }; diff --git a/sdk/cosmos/azure_data_cosmos/src/lib.rs b/sdk/cosmos/azure_data_cosmos/src/lib.rs index dcf5486a41..49c195840f 100644 --- a/sdk/cosmos/azure_data_cosmos/src/lib.rs +++ b/sdk/cosmos/azure_data_cosmos/src/lib.rs @@ -4,20 +4,23 @@ #![doc = include_str!("../README.md")] #![cfg_attr(docsrs, feature(doc_cfg))] +mod cache; pub mod clients; +mod connection; mod connection_string; pub mod constants; mod feed; +mod location_cache; +pub mod models; mod options; mod partition_key; -pub(crate) mod pipeline; pub mod query; -pub(crate) mod resource_context; -pub(crate) mod utils; +mod resource_context; +mod types; +mod utils; +mod status; -pub mod models; - -mod location_cache; +pub use types::ResourceId; #[doc(inline)] pub use clients::CosmosClient; diff --git a/sdk/cosmos/azure_data_cosmos/src/models/mod.rs b/sdk/cosmos/azure_data_cosmos/src/models/mod.rs index dbd3355383..dd0c6222c1 100644 --- a/sdk/cosmos/azure_data_cosmos/src/models/mod.rs +++ b/sdk/cosmos/azure_data_cosmos/src/models/mod.rs @@ -18,6 +18,8 @@ pub use partition_key_definition::*; pub use patch_operations::*; pub use throughput_properties::*; +use crate::ResourceId; + fn deserialize_cosmos_timestamp<'de, D>(deserializer: D) -> Result, D::Error> where D: Deserializer<'de>, @@ -82,7 +84,7 @@ pub struct SystemProperties { // Some APIs do expect the "_rid" to be provided (Replace Offer, for example), so we do want to serialize it if it's provided. #[serde(skip_serializing_if = "Option::is_none")] #[serde(rename = "_rid")] - pub resource_id: Option, + pub resource_id: Option, /// A [`OffsetDateTime`] representing the last modified time of the resource. #[serde(default)] diff --git a/sdk/cosmos/azure_data_cosmos/src/query/executor.rs b/sdk/cosmos/azure_data_cosmos/src/query/executor.rs index f0e93b326d..b465faf3cb 100644 --- a/sdk/cosmos/azure_data_cosmos/src/query/executor.rs +++ b/sdk/cosmos/azure_data_cosmos/src/query/executor.rs @@ -2,15 +2,15 @@ use azure_core::http::{headers::Headers, Context, Method, RawResponse, Request}; use serde::de::DeserializeOwned; use crate::{ + connection::CosmosConnection, constants, - pipeline::{self, CosmosPipeline}, query::{OwnedQueryPipeline, QueryEngineRef, QueryResult}, resource_context::{ResourceLink, ResourceType}, FeedPage, FeedPager, Query, QueryOptions, }; pub struct QueryExecutor { - http_pipeline: CosmosPipeline, + connection: CosmosConnection, container_link: ResourceLink, items_link: ResourceLink, context: Context<'static>, @@ -29,7 +29,7 @@ pub struct QueryExecutor { impl QueryExecutor { pub fn new( - http_pipeline: CosmosPipeline, + connection: CosmosConnection, container_link: ResourceLink, query: Query, options: QueryOptions<'_>, @@ -38,7 +38,7 @@ impl QueryExecutor { let items_link = container_link.feed(ResourceType::Items); let context = options.method_options.context.into_owned(); Ok(Self { - http_pipeline, + connection, container_link, items_link, context, @@ -77,7 +77,7 @@ impl QueryExecutor { None => { // Initialize the pipeline. let query_plan = get_query_plan( - &self.http_pipeline, + &self.connection, &self.items_link, self.context.to_borrowed(), &self.query, @@ -86,7 +86,7 @@ impl QueryExecutor { .await? .into_body(); let pkranges = get_pkranges( - &self.http_pipeline, + &self.connection, &self.container_link, self.context.to_borrowed(), ) @@ -97,8 +97,8 @@ impl QueryExecutor { self.query_engine .create_pipeline(&self.query.text, &query_plan, &pkranges)?; self.query.text = pipeline.query().into(); - self.base_request = Some(crate::pipeline::create_base_query_request( - self.http_pipeline.url(&self.items_link), + self.base_request = Some(crate::connection::create_base_query_request( + self.connection.url(&self.items_link), &self.query, )?); self.pipeline = Some(pipeline); @@ -139,7 +139,7 @@ impl QueryExecutor { } let resp = self - .http_pipeline + .connection .send_raw( self.context.to_borrowed(), &mut query_request, @@ -172,14 +172,14 @@ impl QueryExecutor { // This isn't an inherent method on QueryExecutor because that would force the whole executor to be Sync, which would force the pipeline to be Sync. #[tracing::instrument(skip_all)] async fn get_query_plan( - http_pipeline: &CosmosPipeline, + http_pipeline: &CosmosConnection, items_link: &ResourceLink, context: Context<'_>, query: &Query, supported_features: &str, ) -> azure_core::Result { let url = http_pipeline.url(items_link); - let mut request = pipeline::create_base_query_request(url, query)?; + let mut request = crate::connection::create_base_query_request(url, query)?; request.insert_header(constants::QUERY_ENABLE_CROSS_PARTITION, "True"); request.insert_header(constants::IS_QUERY_PLAN_REQUEST, "True"); request.insert_header( @@ -195,7 +195,7 @@ async fn get_query_plan( // This isn't an inherent method on QueryExecutor because that would force the whole executor to be Sync, which would force the pipeline to be Sync. #[tracing::instrument(skip_all)] async fn get_pkranges( - http_pipeline: &CosmosPipeline, + http_pipeline: &CosmosConnection, container_link: &ResourceLink, context: Context<'_>, ) -> azure_core::Result { diff --git a/sdk/cosmos/azure_data_cosmos/src/resource_context.rs b/sdk/cosmos/azure_data_cosmos/src/resource_context.rs index 69c61fe403..9e84871a99 100644 --- a/sdk/cosmos/azure_data_cosmos/src/resource_context.rs +++ b/sdk/cosmos/azure_data_cosmos/src/resource_context.rs @@ -5,7 +5,7 @@ use url::Url; use crate::utils::url_encode; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[allow(dead_code)] // For the variants. Can be removed when we have them all implemented. pub enum ResourceType { Databases, @@ -41,7 +41,7 @@ impl ResourceType { /// /// This value is URL encoded, and can be [`Url::join`]ed to the endpoint root to produce the full absolute URL for a Cosmos DB resource. /// It's also intended for use by the signature algorithm used when authenticating with a primary key. -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct ResourceLink { parent: Option, item_id: Option, diff --git a/sdk/cosmos/azure_data_cosmos/src/status.rs b/sdk/cosmos/azure_data_cosmos/src/status.rs new file mode 100644 index 0000000000..a9272017b1 --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos/src/status.rs @@ -0,0 +1,187 @@ +//! Defines the [`CosmosStatus`] type, which pairs an HTTP status code with a Cosmos sub-status code. + +use std::num::ParseIntError; + +use azure_core::{ + error::ErrorKind, + http::{headers::FromHeaders, StatusCode}, +}; + +use crate::constants; + +/// A Cosmos sub-status code, which provides additional information about the result of an operation. +/// +/// A specific sub-status code is often only meaningful in the context of a specific HTTP status code. +/// That is, sub-status code `x` may have a different meaning when paired with HTTP status code `A` than it does when paired with HTTP status code `B`. +/// +/// Constants on [`CosmosStatus`] provide the full source-of-truth meanings for specific HTTP status code and sub-status code combinations. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct SubStatusCode(u16); + +impl SubStatusCode { + /// Creates a new `SubStatusCode` from a `u16`. + pub const fn new(value: u16) -> Self { + Self(value) + } + + /// Returns the inner `u16` value of the `SubStatusCode`. + pub const fn value(&self) -> u16 { + self.0 + } +} + +impl std::fmt::Display for SubStatusCode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl FromHeaders for SubStatusCode { + type Error = ParseIntError; + + fn header_names() -> &'static [&'static str] { + // Right now, it's not feasible to extract the static str from HeaderName + &["x-ms-substatus"] + } + + fn from_headers( + headers: &azure_core::http::headers::Headers, + ) -> Result, Self::Error> { + let Some(s) = headers.get_optional_str(&constants::SUB_STATUS) else { + return Ok(None); + }; + let value = s.parse::()?; + Ok(Some(SubStatusCode::new(value))) + } +} + +/// Represents a Cosmos DB status, which is a combination of an HTTP status code and an optional Cosmos sub-status code. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct CosmosStatus(StatusCode, Option); + +impl CosmosStatus { + pub const NAME_CACHE_IS_STALE: CosmosStatus = + CosmosStatus::new(StatusCode::Gone, Some(SubStatusCode::new(1000))); + + pub const fn new(status: StatusCode, substatus: Option) -> Self { + Self(status, substatus) + } + + pub const fn status_code(&self) -> StatusCode { + self.0 + } + + pub const fn substatus_code(&self) -> Option { + self.1 + } +} + +pub trait ErrorExt { + /// Fetches the [`CosmosStatus`] associated with this error, if any. + fn cosmos_status(&self) -> azure_core::Result>; +} + +impl ErrorExt for azure_core::Error { + fn cosmos_status(&self) -> azure_core::Result> { + match self.kind() { + ErrorKind::HttpResponse { + status, + raw_response, + .. + } => { + let substatus: Option = raw_response + .as_ref() + .and_then(|resp| resp.headers().get_optional().transpose()) + .transpose() + .map_err(|_| { + azure_core::Error::with_message( + ErrorKind::DataConversion, + "failed to parse substatus", + ) + })?; + Ok(Some(CosmosStatus::new(*status, substatus))) + } + _ => Ok(None), + } + } +} + +#[cfg(test)] +mod tests { + use azure_core::{ + error::ErrorKind, + http::{headers::Headers, RawResponse, StatusCode}, + }; + + use super::*; + + #[test] + fn cosmos_status_on_non_http_error() { + let err = + azure_core::Error::with_message(azure_core::error::ErrorKind::Other, "test error"); + assert!(err.cosmos_status().unwrap().is_none()); + } + + #[test] + fn cosmos_status_on_http_error_without_substatus() { + let headers = Headers::new(); + let response = RawResponse::from_bytes(StatusCode::Conflict, headers, Vec::new()); + let error = azure_core::Error::with_message( + ErrorKind::HttpResponse { + status: StatusCode::Conflict, + error_code: None, + raw_response: Some(Box::new(response)), + }, + "test error", + ); + assert_eq!( + error.cosmos_status().unwrap(), + Some(CosmosStatus::new(StatusCode::Conflict, None)) + ); + } + + #[test] + fn cosmos_status_on_http_error_with_substatus() { + let mut headers = Headers::new(); + headers.insert( + constants::SUB_STATUS, + CosmosStatus::NAME_CACHE_IS_STALE + .substatus_code() + .unwrap() + .to_string(), + ); + let response = RawResponse::from_bytes( + CosmosStatus::NAME_CACHE_IS_STALE.status_code(), + headers, + Vec::new(), + ); + let error = azure_core::Error::with_message( + ErrorKind::HttpResponse { + status: CosmosStatus::NAME_CACHE_IS_STALE.status_code(), + error_code: None, + raw_response: Some(Box::new(response)), + }, + "test error", + ); + assert_eq!( + error.cosmos_status().unwrap(), + Some(CosmosStatus::NAME_CACHE_IS_STALE) + ); + } + + #[test] + fn cosmos_status_with_invalid_substatus() { + let mut headers = Headers::new(); + headers.insert(constants::SUB_STATUS, "invalid"); + let response = RawResponse::from_bytes(StatusCode::Conflict, headers, Vec::new()); + let error = azure_core::Error::with_message( + ErrorKind::HttpResponse { + status: StatusCode::Conflict, + error_code: None, + raw_response: Some(Box::new(response)), + }, + "test error", + ); + assert!(error.cosmos_status().is_err()); + } +} diff --git a/sdk/cosmos/azure_data_cosmos/src/types.rs b/sdk/cosmos/azure_data_cosmos/src/types.rs new file mode 100644 index 0000000000..4ea1e871f4 --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos/src/types.rs @@ -0,0 +1,47 @@ +//! Internal module to define several newtypes used in the SDK. + +macro_rules! string_newtype { + ($(#[$attr:meta])* $name:ident) => { + $(#[$attr])* + #[derive(serde::Deserialize, serde::Serialize, Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] + #[serde(transparent)] + pub struct $name(String); + + impl $name { + #[doc = concat!("Creates a new `", stringify!($name), "` from a `String`.")] + pub fn new(value: String) -> Self { + Self(value) + } + + #[doc = concat!("Returns a reference to the inner `str` of the `", stringify!($name), "`.")] + pub fn value(&self) -> &str { + &self.0 + } + } + + impl std::fmt::Display for $name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } + } + + impl From<&str> for $name { + fn from(s: &str) -> Self { + Self(s.to_string()) + } + } + + impl From for $name { + fn from(s: String) -> Self { + Self(s) + } + } + }; +} + +string_newtype!( + /// Represents a Resource ID, which is a unique identifier for a resource within a Cosmos DB account. + /// + /// In most cases, you don't need to use this type directly, as the SDK will handle resource IDs for you. + ResourceId +); diff --git a/sdk/cosmos/azure_data_cosmos/tests/cosmos_containers.rs b/sdk/cosmos/azure_data_cosmos/tests/cosmos_containers.rs index c655b7da4c..903be175b1 100644 --- a/sdk/cosmos/azure_data_cosmos/tests/cosmos_containers.rs +++ b/sdk/cosmos/azure_data_cosmos/tests/cosmos_containers.rs @@ -1,9 +1,14 @@ #![cfg(feature = "key_auth")] +#![allow( + unused_imports, + reason = "Each test builds this module separately and that means imports may be unused in certain builds." +)] mod framework; -use std::error::Error; +use std::{error::Error, sync::Arc}; +use azure_core::http::Method; use azure_core_test::{recorded, TestContext}; use azure_data_cosmos::{ models::{ @@ -14,7 +19,7 @@ use azure_data_cosmos::{ }; use futures::TryStreamExt; -use framework::{test_data, TestAccount}; +use framework::{test_data, LocalRecorder, TestAccount, TestAccountOptions}; #[recorded::test] pub async fn container_crud(context: TestContext) -> Result<(), Box> { @@ -237,3 +242,68 @@ pub async fn container_crud_hierarchical_pk(context: TestContext) -> Result<(), Ok(()) } + +#[recorded::test] +pub async fn container_read_throughput_twice(context: TestContext) -> Result<(), Box> { + let recorder = Arc::new(LocalRecorder::new()); + let account = TestAccount::from_env( + context, + Some(TestAccountOptions { + recorder: Some(recorder.clone()), + ..Default::default() + }), + ) + .await?; + + let cosmos_client = account.connect_with_key(None)?; + let db_client = test_data::create_database(&account, &cosmos_client).await?; + + let properties = ContainerProperties { + id: "ThroughputTestContainer".into(), + partition_key: "/id".into(), + ..Default::default() + }; + let throughput = ThroughputProperties::manual(600); + + db_client + .create_container( + properties.clone(), + Some(CreateContainerOptions { + throughput: Some(throughput), + ..Default::default() + }), + ) + .await? + .into_body()?; + let container_client = db_client.container_client(&properties.id); + + let first_throughput = container_client + .read_throughput(None) + .await? + .expect("throughput should be present") + .into_body()?; + assert_eq!(Some(600), first_throughput.throughput()); + + let second_throughput = container_client + .read_throughput(None) + .await? + .expect("throughput should be present") + .into_body()?; + assert_eq!(Some(600), second_throughput.throughput()); + + // Check the recorder to ensure only one request was made to read the container metadata + let txs = recorder.to_transactions().await; + assert_eq!( + 1, + txs.iter() + .filter(|t| t.request.method() == Method::Get + && t.request + .url() + .path() + .ends_with("/colls/ThroughputTestContainer")) + .count() + ); + + account.cleanup().await?; + Ok(()) +} diff --git a/sdk/cosmos/azure_data_cosmos/tests/framework/local_recorder.rs b/sdk/cosmos/azure_data_cosmos/tests/framework/local_recorder.rs new file mode 100644 index 0000000000..c99c334b08 --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos/tests/framework/local_recorder.rs @@ -0,0 +1,56 @@ +use std::sync::Arc; + +use azure_core::http::{ + policies::{Policy, PolicyResult}, + BufResponse, Context, RawResponse, Request, +}; + +#[derive(Debug, Clone)] +pub struct Transaction { + pub request: Request, + pub response: Option, +} + +/// A policy that can be used to capture a simple local recording of requests for validation purposes +pub struct LocalRecorder { + transactions: tokio::sync::RwLock>, +} + +impl LocalRecorder { + pub fn new() -> Self { + Self { + transactions: tokio::sync::RwLock::new(Vec::new()), + } + } + + /// Returns a copy of all recorded transactions + pub async fn to_transactions(&self) -> Vec { + self.transactions.read().await.clone() + } +} + +impl std::fmt::Debug for LocalRecorder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("LocalRecorder").finish() + } +} + +#[async_trait::async_trait] +impl Policy for LocalRecorder { + async fn send( + &self, + ctx: &Context, + request: &mut Request, + next: &[Arc], + ) -> PolicyResult { + let response = next[0].send(ctx, request, &next[1..]).await?; + let (status, headers, body) = response.deconstruct(); + let body = body.collect().await?; + let raw_response = RawResponse::from_bytes(status, headers.clone(), body.clone()); + self.transactions.write().await.push(Transaction { + request: request.clone(), + response: Some(raw_response.clone()), + }); + Ok(BufResponse::from_bytes(status, headers, body)) + } +} diff --git a/sdk/cosmos/azure_data_cosmos/tests/framework/mod.rs b/sdk/cosmos/azure_data_cosmos/tests/framework/mod.rs index 8b204aafb3..6640340d07 100644 --- a/sdk/cosmos/azure_data_cosmos/tests/framework/mod.rs +++ b/sdk/cosmos/azure_data_cosmos/tests/framework/mod.rs @@ -1,20 +1,28 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -// Some tests don't use all the features of this module. -#![allow(dead_code)] +#![allow( + dead_code, + reason = "Some tests don't use all the features of this module." +)] +#![allow( + unused_imports, + reason = "Some tests don't use all the re-exports from this module." +)] //! Provides a framework for integration tests for the Azure Cosmos DB service. //! //! The framework allows tests to easily run against real Cosmos DB instances, the local emulator, or a mock server using test-proxy. +mod local_recorder; mod test_account; pub mod test_data; #[cfg(feature = "preview_query_engine")] pub mod query_engine; -pub use test_account::TestAccount; +pub use local_recorder::LocalRecorder; +pub use test_account::{TestAccount, TestAccountOptions}; use serde::{Deserialize, Serialize}; diff --git a/sdk/cosmos/azure_data_cosmos/tests/framework/test_account.rs b/sdk/cosmos/azure_data_cosmos/tests/framework/test_account.rs index 6cd8eeb215..55d52386e5 100644 --- a/sdk/cosmos/azure_data_cosmos/tests/framework/test_account.rs +++ b/sdk/cosmos/azure_data_cosmos/tests/framework/test_account.rs @@ -9,6 +9,8 @@ use azure_core_test::TestContext; use azure_data_cosmos::{ConnectionString, CosmosClientOptions, Query}; use reqwest::ClientBuilder; +use crate::framework::LocalRecorder; + /// Represents a Cosmos DB account for testing purposes. /// /// A [`TestAccount`] serves two main purposes: @@ -25,6 +27,7 @@ pub struct TestAccount { #[derive(Default)] pub struct TestAccountOptions { pub allow_invalid_certificates: Option, + pub recorder: Option>, } const CONNECTION_STRING_ENV_VAR: &str = "AZURE_COSMOS_CONNECTION_STRING"; @@ -113,6 +116,13 @@ impl TestAccount { .recording() .instrument(&mut options.client_options); + if let Some(recorder) = &self.options.recorder { + options + .client_options + .per_try_policies + .push(recorder.clone()); + } + Ok(azure_data_cosmos::CosmosClient::with_key( &self.endpoint, self.key.clone(),