diff --git a/sources/Cargo.lock b/sources/Cargo.lock index 0e4cb42bc..2d59b9c9a 100644 --- a/sources/Cargo.lock +++ b/sources/Cargo.lock @@ -3148,6 +3148,7 @@ dependencies = [ "generate-readme", "headers", "http 0.2.12", + "httptest", "hyper", "hyper-rustls", "imdsclient", diff --git a/sources/api/pluto/Cargo.toml b/sources/api/pluto/Cargo.toml index 89a706d75..deafa1c62 100644 --- a/sources/api/pluto/Cargo.toml +++ b/sources/api/pluto/Cargo.toml @@ -40,3 +40,6 @@ version = "0.2.0" [build-dependencies] generate-readme = { version = "0.1", path = "../../generate-readme" } + +[dev-dependencies] +httptest = "0.15" diff --git a/sources/api/pluto/src/api.rs b/sources/api/pluto/src/api.rs index e4b06276f..93e3c4ece 100644 --- a/sources/api/pluto/src/api.rs +++ b/sources/api/pluto/src/api.rs @@ -27,6 +27,9 @@ pub(crate) struct AwsK8sInfo { pub(crate) provider_id: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub(crate) hostname_override: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub(crate) hostname_override_source: + Option, #[serde(skip)] pub(crate) variant_id: String, } @@ -53,6 +56,9 @@ pub(crate) struct Kubernetes { pub(crate) provider_id: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub(crate) hostname_override: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub(crate) hostname_override_source: + Option, } #[derive(Serialize, Deserialize)] @@ -173,5 +179,10 @@ pub(crate) async fn get_aws_k8s_info() -> Result { .kubernetes .as_ref() .and_then(|k| k.hostname_override.clone()), + hostname_override_source: view + .settings + .kubernetes + .as_ref() + .and_then(|k| k.hostname_override_source.clone()), }) } diff --git a/sources/api/pluto/src/aws.rs b/sources/api/pluto/src/aws.rs index ec29560d9..49b7c330d 100644 --- a/sources/api/pluto/src/aws.rs +++ b/sources/api/pluto/src/aws.rs @@ -25,7 +25,7 @@ pub(crate) async fn sdk_config(region: &str) -> SdkConfig { .imds_client(sdk_imds_client()) .build() .await; - aws_config::defaults(BehaviorVersion::v2023_11_09()) + aws_config::defaults(BehaviorVersion::v2024_03_28()) .region(Region::new(region.to_owned())) .credentials_provider(provider) .retry_config(sdk_retry_config()) diff --git a/sources/api/pluto/src/main.rs b/sources/api/pluto/src/main.rs index 6c6db8cc0..9d2023d41 100644 --- a/sources/api/pluto/src/main.rs +++ b/sources/api/pluto/src/main.rs @@ -39,7 +39,7 @@ mod hyper_proxy; mod proxy; use api::AwsK8sInfo; -use bottlerocket_modeled_types::KubernetesClusterDnsIp; +use bottlerocket_modeled_types::{KubernetesClusterDnsIp, KubernetesHostnameOverrideSource}; use imdsclient::ImdsClient; use snafu::{ensure, OptionExt, ResultExt}; use std::fs::File; @@ -346,14 +346,26 @@ async fn generate_provider_id( Ok(()) } -async fn generate_private_dns_name( - client: &mut ImdsClient, - aws_k8s_info: &mut AwsK8sInfo, -) -> Result<()> { - if aws_k8s_info.hostname_override.is_some() || NO_HOSTNAME_VARIANTS.contains(&aws_k8s_info.variant_id.as_str()) { +/// generate_node_name sets the hostname_override, if it is not already specified +async fn generate_node_name(client: &mut ImdsClient, aws_k8s_info: &mut AwsK8sInfo) -> Result<()> { + // hostname override provided, so we do nothing + if aws_k8s_info.hostname_override.is_some() { return Ok(()); } + // no override source provided, and this is one of the no hostname variants + if aws_k8s_info.hostname_override_source.is_none() + && NO_HOSTNAME_VARIANTS.contains(&aws_k8s_info.variant_id.as_str()) + { + return Ok(()); + } + + // use the hostname source provided, defaulting to the private DNS name + let hostname_source = aws_k8s_info + .hostname_override_source + .clone() + .unwrap_or(KubernetesHostnameOverrideSource::PrivateDNSName); + let region = aws_k8s_info .region .as_ref() @@ -365,16 +377,25 @@ async fn generate_private_dns_name( .context(error::ImdsNoneSnafu { what: "instance ID", })?; - aws_k8s_info.hostname_override = Some( - ec2::get_private_dns_name( - region, - &instance_id, - aws_k8s_info.https_proxy.clone(), - aws_k8s_info.no_proxy.clone(), - ) - .await - .context(error::Ec2Snafu)?, - ); + + match hostname_source { + KubernetesHostnameOverrideSource::PrivateDNSName => { + aws_k8s_info.hostname_override = Some( + ec2::get_private_dns_name( + region, + &instance_id, + aws_k8s_info.https_proxy.clone(), + aws_k8s_info.no_proxy.clone(), + ) + .await + .context(error::Ec2Snafu)?, + ); + } + KubernetesHostnameOverrideSource::InstanceID => { + aws_k8s_info.hostname_override = Some(instance_id); + } + } + Ok(()) } @@ -386,7 +407,7 @@ async fn run() -> Result<()> { generate_node_ip(&mut client, &mut aws_k8s_info).await?; generate_max_pods(&mut client, &mut aws_k8s_info).await?; generate_provider_id(&mut client, &mut aws_k8s_info).await?; - generate_private_dns_name(&mut client, &mut aws_k8s_info).await?; + generate_node_name(&mut client, &mut aws_k8s_info).await?; let settings = serde_json::to_value(&aws_k8s_info).context(error::SerializeSnafu)?; let generated_settings: serde_json::Value = serde_json::json!({ @@ -416,17 +437,115 @@ async fn main() { } } -#[test] -fn test_get_dns_from_cidr_ok() { - let input = "123.456.789.0/123"; - let expected = "123.456.789.10"; - let actual = get_dns_from_ipv4_cidr(input).unwrap(); - assert_eq!(expected, actual); -} +#[cfg(test)] +mod test { + use super::*; + use httptest::{matchers::*, responders::*, Expectation, Server}; + + #[test] + fn test_get_dns_from_cidr_ok() { + let input = "123.456.789.0/123"; + let expected = "123.456.789.10"; + let actual = get_dns_from_ipv4_cidr(input).unwrap(); + assert_eq!(expected, actual); + } -#[test] -fn test_get_dns_from_cidr_err() { - let input = "123_456_789_0/123"; - let result = get_dns_from_ipv4_cidr(input); - assert!(result.is_err()); + #[test] + fn test_get_dns_from_cidr_err() { + let input = "123_456_789_0/123"; + let result = get_dns_from_ipv4_cidr(input); + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_hostname_override_source() { + let server = Server::run(); + let base_uri = format!("http://{}", server.addr()); + println!("listen on {}", base_uri); + let token = "some+token"; + let schema_version = "2021-07-15"; + let target = "meta-data/instance-id"; + let response_code = 200; + let response_body = "i-123456789"; + server.expect( + Expectation::matching(request::method_path("PUT", "/latest/api/token")) + .times(1) + .respond_with( + status_code(200) + .append_header("X-aws-ec2-metadata-token-ttl-seconds", "60") + .body(token), + ), + ); + server.expect( + Expectation::matching(request::method_path( + "GET", + format!("/{}/{}", schema_version, target), + )) + .times(1) + .respond_with( + status_code(response_code) + .append_header("X-aws-ec2-metadata-token", token) + .body(response_body), + ), + ); + + let mut imds_client = ImdsClient::new_impl(base_uri); + + let mut info = AwsK8sInfo { + region: Some(String::from("us-west-2")), + https_proxy: None, + no_proxy: None, + cluster_name: None, + cluster_dns_ip: None, + node_ip: None, + max_pods: None, + provider_id: None, + hostname_override: None, + hostname_override_source: None, + variant_id: "".to_string(), + }; + + // specifying a hostname will cause it to be used + info.hostname_override = Some(String::from("hostname-specified")); + generate_node_name(&mut imds_client, &mut info) + .await + .unwrap(); + assert_eq!( + info.hostname_override, + Some(String::from("hostname-specified")) + ); + + // regardless of the hostname override source + info.hostname_override = Some(String::from("hostname-specified")); + info.hostname_override_source = Some(KubernetesHostnameOverrideSource::InstanceID); + generate_node_name(&mut imds_client, &mut info) + .await + .unwrap(); + assert_eq!( + info.hostname_override, + Some(String::from("hostname-specified")) + ); + + // no override with an old K8s, results in no setting of the hostname_override + info.hostname_override = None; + info.hostname_override_source = None; + info.variant_id = String::from("aws-k8s-1.23"); + generate_node_name(&mut imds_client, &mut info) + .await + .unwrap(); + assert_eq!(info.hostname_override, None); + + // skipping tests that call use the private dns name since we would need to make the EC2 + // API mockable to implement them + + // specifying no hostname, with override of instance-id causes the instance-id to be used + // and pulled from IMDS + info.hostname_override = None; + info.hostname_override_source = Some(KubernetesHostnameOverrideSource::InstanceID); + info.variant_id = String::from("aws-k8s-1.29"); + generate_node_name(&mut imds_client, &mut info) + .await + .unwrap(); + assert_eq!(info.hostname_override, Some(String::from("i-123456789"))); + } } diff --git a/sources/imdsclient/src/lib.rs b/sources/imdsclient/src/lib.rs index 6f3fb62e5..b39c52d63 100644 --- a/sources/imdsclient/src/lib.rs +++ b/sources/imdsclient/src/lib.rs @@ -60,7 +60,8 @@ impl ImdsClient { Self::new_impl(BASE_URI.to_string()) } - fn new_impl(imds_base_uri: String) -> Self { + /// Exposed solely to allow unit testing. + pub fn new_impl(imds_base_uri: String) -> Self { Self { client: Client::new(), retry_timeout: Duration::from_secs(RETRY_TIMEOUT_SECS),