Skip to content

Commit

Permalink
beautiful 400 error for duplicate firewall rule names
Browse files Browse the repository at this point in the history
  • Loading branch information
david-crespo committed Jun 18, 2024
1 parent 754fcd0 commit f873555
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 24 deletions.
36 changes: 33 additions & 3 deletions nexus/db-model/src/vpc_firewall_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ use nexus_types::identity::Resource;
use omicron_common::api::external;
use serde::Deserialize;
use serde::Serialize;
use serde_json::json;
use std::collections::HashMap;
use std::io::Write;
use uuid::Uuid;

Expand Down Expand Up @@ -253,15 +255,43 @@ impl VpcFirewallRule {
pub fn vec_from_params(
vpc_id: Uuid,
params: external::VpcFirewallRuleUpdateParams,
) -> Vec<VpcFirewallRule> {
params
) -> Result<Vec<VpcFirewallRule>, external::Error> {
ensure_no_duplicates(&params)?;
Ok(params
.rules
.iter()
.map(|rule| VpcFirewallRule::new(Uuid::new_v4(), vpc_id, rule))
.collect()
.collect())
}
}

fn ensure_no_duplicates(
params: &external::VpcFirewallRuleUpdateParams,
) -> Result<(), external::Error> {
// we could do this by comparing set(names).len() to names.len(), but this way
// we can say what the duplicate names are, and that's nice!
let mut name_counts = HashMap::new();
for r in params.rules.iter() {
*name_counts.entry(r.name.to_string()).or_insert(0) += 1
}
let duplicates: Vec<String> = name_counts
.into_iter()
.filter(|(_, count)| *count > 1)
.map(|(name, _)| name)
.collect();

if !duplicates.is_empty() {
return Err(external::Error::invalid_value(
"rules",
format!(
"Rules names must be unique. Duplicates: {}",
json!(duplicates)
),
));
}
Ok(())
}

impl Into<external::VpcFirewallRule> for VpcFirewallRule {
fn into(self) -> external::VpcFirewallRule {
external::VpcFirewallRule {
Expand Down
5 changes: 3 additions & 2 deletions nexus/src/app/vpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,8 @@ impl super::Nexus {
let rules = db::model::VpcFirewallRule::vec_from_params(
authz_vpc.id(),
params.clone(),
);
)?;

let rules = self
.db_datastore
.vpc_update_firewall_rules(opctx, &authz_vpc, rules)
Expand All @@ -199,7 +200,7 @@ impl super::Nexus {
let mut rules = db::model::VpcFirewallRule::vec_from_params(
vpc_id,
defaults::DEFAULT_FIREWALL_RULES.clone(),
);
)?;
for rule in rules.iter_mut() {
for target in rule.targets.iter_mut() {
match target.0 {
Expand Down
67 changes: 48 additions & 19 deletions nexus/tests/integration_tests/vpc_firewall.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
use http::method::Method;
use http::StatusCode;
use nexus_test_utils::http_testing::{AuthnMode, NexusRequest};
use nexus_test_utils::resource_helpers::{create_project, create_vpc};
use nexus_test_utils::resource_helpers::{
create_project, create_vpc, object_get, object_put_error,
};
use nexus_test_utils_macros::nexus_test;
use nexus_types::external_api::views::Vpc;
use omicron_common::api::external::{
Expand Down Expand Up @@ -42,7 +44,9 @@ async fn test_vpc_firewall(cptestctx: &ControlPlaneTestContext) {

let default_vpc_firewall =
format!("/v1/vpc-firewall-rules?vpc=default&{}", project_selector,);
let rules = get_rules(client, &default_vpc_firewall).await;
let rules = object_get::<VpcFirewallRules>(client, &default_vpc_firewall)
.await
.rules;
assert!(rules.iter().all(|r| r.vpc_id == default_vpc.identity.id));
assert!(is_default_firewall_rules("default", &rules));

Expand All @@ -52,7 +56,8 @@ async fn test_vpc_firewall(cptestctx: &ControlPlaneTestContext) {
let other_vpc_firewall =
format!("/v1/vpc-firewall-rules?{}", other_vpc_selector);
let vpc2 = create_vpc(&client, &project_name, &other_vpc).await;
let rules = get_rules(client, &other_vpc_firewall).await;
let rules =
object_get::<VpcFirewallRules>(client, &other_vpc_firewall).await.rules;
assert!(rules.iter().all(|r| r.vpc_id == vpc2.identity.id));
assert!(is_default_firewall_rules(other_vpc, &rules));

Expand Down Expand Up @@ -111,14 +116,17 @@ async fn test_vpc_firewall(cptestctx: &ControlPlaneTestContext) {
assert_eq!(updated_rules[1].identity.name, "deny-all-incoming");

// Make sure the firewall is changed
let rules = get_rules(client, &default_vpc_firewall).await;
let rules = object_get::<VpcFirewallRules>(client, &default_vpc_firewall)
.await
.rules;
assert!(!is_default_firewall_rules("default", &rules));
assert_eq!(rules.len(), new_rules.len());
assert_eq!(rules[0].identity.name, "allow-icmp");
assert_eq!(rules[1].identity.name, "deny-all-incoming");

// Make sure the other firewall is unchanged
let rules = get_rules(client, &other_vpc_firewall).await;
let rules =
object_get::<VpcFirewallRules>(client, &other_vpc_firewall).await.rules;
assert!(is_default_firewall_rules(other_vpc, &rules));

// DELETE is unsupported
Expand Down Expand Up @@ -162,20 +170,6 @@ async fn test_vpc_firewall(cptestctx: &ControlPlaneTestContext) {
.unwrap();
}

async fn get_rules(
client: &dropshot::test_util::ClientTestContext,
url: &str,
) -> Vec<VpcFirewallRule> {
NexusRequest::object_get(client, url)
.authn_as(AuthnMode::PrivilegedUser)
.execute()
.await
.unwrap()
.parsed_body::<VpcFirewallRules>()
.unwrap()
.rules
}

fn is_default_firewall_rules(
vpc_name: &str,
rules: &Vec<VpcFirewallRule>,
Expand Down Expand Up @@ -292,3 +286,38 @@ fn is_default_firewall_rules(
}
true
}

#[nexus_test]
async fn test_firewall_rules_same_name(cptestctx: &ControlPlaneTestContext) {
let client = &cptestctx.external_client;

let project_name = "my-project";
create_project(&client, &project_name).await;

let rule = VpcFirewallRuleUpdate {
name: "dupe".parse().unwrap(),
description: "".to_string(),
status: VpcFirewallRuleStatus::Enabled,
direction: VpcFirewallRuleDirection::Inbound,
targets: vec![],
filters: VpcFirewallRuleFilter {
hosts: None,
protocols: None,
ports: None,
},
action: VpcFirewallRuleAction::Allow,
priority: VpcFirewallRulePriority(65534),
};

let error = object_put_error(
client,
&format!("/v1/vpc-firewall-rules?vpc=default&project={}", project_name),
&VpcFirewallRuleUpdateParams {
rules: vec![rule.clone(), rule.clone()],
},
StatusCode::BAD_REQUEST,
)
.await;
assert_eq!(error.error_code, Some("InvalidValue".to_string()));
assert_eq!(error.message, "unsupported value for \"rules\": Rules names must be unique. Duplicates: [\"dupe\"]");
}

0 comments on commit f873555

Please sign in to comment.