diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8044b2f47..008158fb0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,7 +20,7 @@ jobs: - uses: actions/checkout@v3 - uses: sfackler/actions/rustup@master - uses: sfackler/actions/rustfmt@master - + clippy: name: clippy runs-on: ubuntu-latest @@ -47,6 +47,33 @@ jobs: key: clippy-target-${{ runner.os }}-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }}y - run: cargo clippy --all --all-targets + check-wasm32: + name: check-wasm32 + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: sfackler/actions/rustup@master + - run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT + id: rust-version + - run: rustup target add wasm32-unknown-unknown + - uses: actions/cache@v3 + with: + path: ~/.cargo/registry/index + key: index-${{ runner.os }}-${{ github.run_number }} + restore-keys: | + index-${{ runner.os }}- + - run: cargo generate-lockfile + - uses: actions/cache@v3 + with: + path: ~/.cargo/registry/cache + key: registry-${{ runner.os }}-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }} + - run: cargo fetch + - uses: actions/cache@v3 + with: + path: target + key: check-wasm32-target-${{ runner.os }}-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }} + - run: cargo check --target wasm32-unknown-unknown --manifest-path tokio-postgres/Cargo.toml --no-default-features --features js + test: name: test runs-on: ubuntu-latest @@ -55,7 +82,7 @@ jobs: - run: docker compose up -d - uses: sfackler/actions/rustup@master with: - version: 1.64.0 + version: 1.70.0 - run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT id: rust-version - uses: actions/cache@v3 diff --git a/Cargo.toml b/Cargo.toml index 4752836a7..16e3739dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ members = [ "postgres-types", "tokio-postgres", ] +resolver = "2" [profile.release] debug = 2 diff --git a/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.rs b/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.rs new file mode 100644 index 000000000..52d0ba8f6 --- /dev/null +++ b/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.rs @@ -0,0 +1,31 @@ +use postgres_types::{FromSql, ToSql}; + +#[derive(ToSql, Debug)] +#[postgres(allow_mismatch)] +struct ToSqlAllowMismatchStruct { + a: i32, +} + +#[derive(FromSql, Debug)] +#[postgres(allow_mismatch)] +struct FromSqlAllowMismatchStruct { + a: i32, +} + +#[derive(ToSql, Debug)] +#[postgres(allow_mismatch)] +struct ToSqlAllowMismatchTupleStruct(i32, i32); + +#[derive(FromSql, Debug)] +#[postgres(allow_mismatch)] +struct FromSqlAllowMismatchTupleStruct(i32, i32); + +#[derive(FromSql, Debug)] +#[postgres(transparent, allow_mismatch)] +struct TransparentFromSqlAllowMismatchStruct(i32); + +#[derive(FromSql, Debug)] +#[postgres(allow_mismatch, transparent)] +struct AllowMismatchFromSqlTransparentStruct(i32); + +fn main() {} diff --git a/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.stderr b/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.stderr new file mode 100644 index 000000000..a8e573248 --- /dev/null +++ b/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.stderr @@ -0,0 +1,43 @@ +error: #[postgres(allow_mismatch)] may only be applied to enums + --> src/compile-fail/invalid-allow-mismatch.rs:4:1 + | +4 | / #[postgres(allow_mismatch)] +5 | | struct ToSqlAllowMismatchStruct { +6 | | a: i32, +7 | | } + | |_^ + +error: #[postgres(allow_mismatch)] may only be applied to enums + --> src/compile-fail/invalid-allow-mismatch.rs:10:1 + | +10 | / #[postgres(allow_mismatch)] +11 | | struct FromSqlAllowMismatchStruct { +12 | | a: i32, +13 | | } + | |_^ + +error: #[postgres(allow_mismatch)] may only be applied to enums + --> src/compile-fail/invalid-allow-mismatch.rs:16:1 + | +16 | / #[postgres(allow_mismatch)] +17 | | struct ToSqlAllowMismatchTupleStruct(i32, i32); + | |_______________________________________________^ + +error: #[postgres(allow_mismatch)] may only be applied to enums + --> src/compile-fail/invalid-allow-mismatch.rs:20:1 + | +20 | / #[postgres(allow_mismatch)] +21 | | struct FromSqlAllowMismatchTupleStruct(i32, i32); + | |_________________________________________________^ + +error: #[postgres(transparent)] is not allowed with #[postgres(allow_mismatch)] + --> src/compile-fail/invalid-allow-mismatch.rs:24:25 + | +24 | #[postgres(transparent, allow_mismatch)] + | ^^^^^^^^^^^^^^ + +error: #[postgres(allow_mismatch)] is not allowed with #[postgres(transparent)] + --> src/compile-fail/invalid-allow-mismatch.rs:28:28 + | +28 | #[postgres(allow_mismatch, transparent)] + | ^^^^^^^^^^^ diff --git a/postgres-derive-test/src/composites.rs b/postgres-derive-test/src/composites.rs index a1b76345f..50a22790d 100644 --- a/postgres-derive-test/src/composites.rs +++ b/postgres-derive-test/src/composites.rs @@ -89,6 +89,49 @@ fn name_overrides() { ); } +#[test] +fn rename_all_overrides() { + #[derive(FromSql, ToSql, Debug, PartialEq)] + #[postgres(name = "inventory_item", rename_all = "SCREAMING_SNAKE_CASE")] + struct InventoryItem { + name: String, + supplier_id: i32, + #[postgres(name = "Price")] + price: Option, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.batch_execute( + "CREATE TYPE pg_temp.inventory_item AS ( + \"NAME\" TEXT, + \"SUPPLIER_ID\" INT, + \"Price\" DOUBLE PRECISION + );", + ) + .unwrap(); + + let item = InventoryItem { + name: "foobar".to_owned(), + supplier_id: 100, + price: Some(15.50), + }; + + let item_null = InventoryItem { + name: "foobar".to_owned(), + supplier_id: 100, + price: None, + }; + + test_type( + &mut conn, + "inventory_item", + &[ + (item, "ROW('foobar', 100, 15.50)"), + (item_null, "ROW('foobar', 100, NULL)"), + ], + ); +} + #[test] fn wrong_name() { #[derive(FromSql, ToSql, Debug, PartialEq)] diff --git a/postgres-derive-test/src/enums.rs b/postgres-derive-test/src/enums.rs index a7039ca05..f3e6c488c 100644 --- a/postgres-derive-test/src/enums.rs +++ b/postgres-derive-test/src/enums.rs @@ -1,5 +1,5 @@ use crate::test_type; -use postgres::{Client, NoTls}; +use postgres::{error::DbError, Client, NoTls}; use postgres_types::{FromSql, ToSql, WrongType}; use std::error::Error; @@ -53,6 +53,35 @@ fn name_overrides() { ); } +#[test] +fn rename_all_overrides() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(name = "mood", rename_all = "snake_case")] + enum Mood { + VerySad, + #[postgres(name = "okay")] + Ok, + VeryHappy, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute( + "CREATE TYPE pg_temp.mood AS ENUM ('very_sad', 'okay', 'very_happy')", + &[], + ) + .unwrap(); + + test_type( + &mut conn, + "mood", + &[ + (Mood::VerySad, "'very_sad'"), + (Mood::Ok, "'okay'"), + (Mood::VeryHappy, "'very_happy'"), + ], + ); +} + #[test] fn wrong_name() { #[derive(Debug, ToSql, FromSql, PartialEq)] @@ -102,3 +131,73 @@ fn missing_variant() { let err = conn.execute("SELECT $1::foo", &[&Foo::Bar]).unwrap_err(); assert!(err.source().unwrap().is::()); } + +#[test] +fn allow_mismatch_enums() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(allow_mismatch)] + enum Foo { + Bar, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute("CREATE TYPE pg_temp.\"Foo\" AS ENUM ('Bar', 'Baz')", &[]) + .unwrap(); + + let row = conn.query_one("SELECT $1::\"Foo\"", &[&Foo::Bar]).unwrap(); + assert_eq!(row.get::<_, Foo>(0), Foo::Bar); +} + +#[test] +fn missing_enum_variant() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(allow_mismatch)] + enum Foo { + Bar, + Buz, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute("CREATE TYPE pg_temp.\"Foo\" AS ENUM ('Bar', 'Baz')", &[]) + .unwrap(); + + let err = conn + .query_one("SELECT $1::\"Foo\"", &[&Foo::Buz]) + .unwrap_err(); + assert!(err.source().unwrap().is::()); +} + +#[test] +fn allow_mismatch_and_renaming() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(name = "foo", allow_mismatch)] + enum Foo { + #[postgres(name = "bar")] + Bar, + #[postgres(name = "buz")] + Buz, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('bar', 'baz', 'buz')", &[]) + .unwrap(); + + let row = conn.query_one("SELECT $1::foo", &[&Foo::Buz]).unwrap(); + assert_eq!(row.get::<_, Foo>(0), Foo::Buz); +} + +#[test] +fn wrong_name_and_allow_mismatch() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(allow_mismatch)] + enum Foo { + Bar, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('Bar', 'Baz')", &[]) + .unwrap(); + + let err = conn.query_one("SELECT $1::foo", &[&Foo::Bar]).unwrap_err(); + assert!(err.source().unwrap().is::()); +} diff --git a/postgres-derive/CHANGELOG.md b/postgres-derive/CHANGELOG.md index 22714acc2..b0075fa8e 100644 --- a/postgres-derive/CHANGELOG.md +++ b/postgres-derive/CHANGELOG.md @@ -1,5 +1,12 @@ # Change Log +## v0.4.5 - 2023-08-19 + +### Added + +* Added a `rename_all` option for enum and struct derives. +* Added an `allow_mismatch` option to disable strict enum variant checks against the Postgres type. + ## v0.4.4 - 2023-03-27 ### Changed diff --git a/postgres-derive/Cargo.toml b/postgres-derive/Cargo.toml index 535a64315..51ebb5663 100644 --- a/postgres-derive/Cargo.toml +++ b/postgres-derive/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "postgres-derive" -version = "0.4.4" +version = "0.4.5" authors = ["Steven Fackler "] license = "MIT/Apache-2.0" edition = "2018" @@ -15,3 +15,4 @@ test = false syn = "2.0" proc-macro2 = "1.0" quote = "1.0" +heck = "0.4" \ No newline at end of file diff --git a/postgres-derive/src/accepts.rs b/postgres-derive/src/accepts.rs index 63473863a..a68538dcc 100644 --- a/postgres-derive/src/accepts.rs +++ b/postgres-derive/src/accepts.rs @@ -31,31 +31,37 @@ pub fn domain_body(name: &str, field: &syn::Field) -> TokenStream { } } -pub fn enum_body(name: &str, variants: &[Variant]) -> TokenStream { +pub fn enum_body(name: &str, variants: &[Variant], allow_mismatch: bool) -> TokenStream { let num_variants = variants.len(); let variant_names = variants.iter().map(|v| &v.name); - quote! { - if type_.name() != #name { - return false; + if allow_mismatch { + quote! { + type_.name() == #name } + } else { + quote! { + if type_.name() != #name { + return false; + } - match *type_.kind() { - ::postgres_types::Kind::Enum(ref variants) => { - if variants.len() != #num_variants { - return false; - } - - variants.iter().all(|v| { - match &**v { - #( - #variant_names => true, - )* - _ => false, + match *type_.kind() { + ::postgres_types::Kind::Enum(ref variants) => { + if variants.len() != #num_variants { + return false; } - }) + + variants.iter().all(|v| { + match &**v { + #( + #variant_names => true, + )* + _ => false, + } + }) + } + _ => false, } - _ => false, } } } diff --git a/postgres-derive/src/case.rs b/postgres-derive/src/case.rs new file mode 100644 index 000000000..20ecc8eed --- /dev/null +++ b/postgres-derive/src/case.rs @@ -0,0 +1,110 @@ +#[allow(deprecated, unused_imports)] +use std::ascii::AsciiExt; + +use heck::{ + ToKebabCase, ToLowerCamelCase, ToShoutyKebabCase, ToShoutySnakeCase, ToSnakeCase, ToTrainCase, + ToUpperCamelCase, +}; + +use self::RenameRule::*; + +/// The different possible ways to change case of fields in a struct, or variants in an enum. +#[allow(clippy::enum_variant_names)] +#[derive(Copy, Clone, PartialEq)] +pub enum RenameRule { + /// Rename direct children to "lowercase" style. + LowerCase, + /// Rename direct children to "UPPERCASE" style. + UpperCase, + /// Rename direct children to "PascalCase" style, as typically used for + /// enum variants. + PascalCase, + /// Rename direct children to "camelCase" style. + CamelCase, + /// Rename direct children to "snake_case" style, as commonly used for + /// fields. + SnakeCase, + /// Rename direct children to "SCREAMING_SNAKE_CASE" style, as commonly + /// used for constants. + ScreamingSnakeCase, + /// Rename direct children to "kebab-case" style. + KebabCase, + /// Rename direct children to "SCREAMING-KEBAB-CASE" style. + ScreamingKebabCase, + + /// Rename direct children to "Train-Case" style. + TrainCase, +} + +pub const RENAME_RULES: &[&str] = &[ + "lowercase", + "UPPERCASE", + "PascalCase", + "camelCase", + "snake_case", + "SCREAMING_SNAKE_CASE", + "kebab-case", + "SCREAMING-KEBAB-CASE", + "Train-Case", +]; + +impl RenameRule { + pub fn from_str(rule: &str) -> Option { + match rule { + "lowercase" => Some(LowerCase), + "UPPERCASE" => Some(UpperCase), + "PascalCase" => Some(PascalCase), + "camelCase" => Some(CamelCase), + "snake_case" => Some(SnakeCase), + "SCREAMING_SNAKE_CASE" => Some(ScreamingSnakeCase), + "kebab-case" => Some(KebabCase), + "SCREAMING-KEBAB-CASE" => Some(ScreamingKebabCase), + "Train-Case" => Some(TrainCase), + _ => None, + } + } + /// Apply a renaming rule to an enum or struct field, returning the version expected in the source. + pub fn apply_to_field(&self, variant: &str) -> String { + match *self { + LowerCase => variant.to_lowercase(), + UpperCase => variant.to_uppercase(), + PascalCase => variant.to_upper_camel_case(), + CamelCase => variant.to_lower_camel_case(), + SnakeCase => variant.to_snake_case(), + ScreamingSnakeCase => variant.to_shouty_snake_case(), + KebabCase => variant.to_kebab_case(), + ScreamingKebabCase => variant.to_shouty_kebab_case(), + TrainCase => variant.to_train_case(), + } + } +} + +#[test] +fn rename_field() { + for &(original, lower, upper, camel, snake, screaming, kebab, screaming_kebab) in &[ + ( + "Outcome", "outcome", "OUTCOME", "outcome", "outcome", "OUTCOME", "outcome", "OUTCOME", + ), + ( + "VeryTasty", + "verytasty", + "VERYTASTY", + "veryTasty", + "very_tasty", + "VERY_TASTY", + "very-tasty", + "VERY-TASTY", + ), + ("A", "a", "A", "a", "a", "A", "a", "A"), + ("Z42", "z42", "Z42", "z42", "z42", "Z42", "z42", "Z42"), + ] { + assert_eq!(LowerCase.apply_to_field(original), lower); + assert_eq!(UpperCase.apply_to_field(original), upper); + assert_eq!(PascalCase.apply_to_field(original), original); + assert_eq!(CamelCase.apply_to_field(original), camel); + assert_eq!(SnakeCase.apply_to_field(original), snake); + assert_eq!(ScreamingSnakeCase.apply_to_field(original), screaming); + assert_eq!(KebabCase.apply_to_field(original), kebab); + assert_eq!(ScreamingKebabCase.apply_to_field(original), screaming_kebab); + } +} diff --git a/postgres-derive/src/composites.rs b/postgres-derive/src/composites.rs index 15bfabc13..b6aad8ab3 100644 --- a/postgres-derive/src/composites.rs +++ b/postgres-derive/src/composites.rs @@ -4,7 +4,7 @@ use syn::{ TypeParamBound, }; -use crate::overrides::Overrides; +use crate::{case::RenameRule, overrides::Overrides}; pub struct Field { pub name: String, @@ -13,18 +13,26 @@ pub struct Field { } impl Field { - pub fn parse(raw: &syn::Field) -> Result { - let overrides = Overrides::extract(&raw.attrs)?; - + pub fn parse(raw: &syn::Field, rename_all: Option) -> Result { + let overrides = Overrides::extract(&raw.attrs, false)?; let ident = raw.ident.as_ref().unwrap().clone(); - Ok(Field { - name: overrides.name.unwrap_or_else(|| { + + // field level name override takes precendence over container level rename_all override + let name = match overrides.name { + Some(n) => n, + None => { let name = ident.to_string(); - match name.strip_prefix("r#") { - Some(name) => name.to_string(), - None => name, + let stripped = name.strip_prefix("r#").map(String::from).unwrap_or(name); + + match rename_all { + Some(rule) => rule.apply_to_field(&stripped), + None => stripped, } - }), + } + }; + + Ok(Field { + name, ident, type_: raw.ty.clone(), }) diff --git a/postgres-derive/src/enums.rs b/postgres-derive/src/enums.rs index 3c6bc7113..9a6dfa926 100644 --- a/postgres-derive/src/enums.rs +++ b/postgres-derive/src/enums.rs @@ -1,6 +1,6 @@ use syn::{Error, Fields, Ident}; -use crate::overrides::Overrides; +use crate::{case::RenameRule, overrides::Overrides}; pub struct Variant { pub ident: Ident, @@ -8,7 +8,7 @@ pub struct Variant { } impl Variant { - pub fn parse(raw: &syn::Variant) -> Result { + pub fn parse(raw: &syn::Variant, rename_all: Option) -> Result { match raw.fields { Fields::Unit => {} _ => { @@ -18,11 +18,16 @@ impl Variant { )) } } + let overrides = Overrides::extract(&raw.attrs, false)?; - let overrides = Overrides::extract(&raw.attrs)?; + // variant level name override takes precendence over container level rename_all override + let name = overrides.name.unwrap_or_else(|| match rename_all { + Some(rule) => rule.apply_to_field(&raw.ident.to_string()), + None => raw.ident.to_string(), + }); Ok(Variant { ident: raw.ident.clone(), - name: overrides.name.unwrap_or_else(|| raw.ident.to_string()), + name, }) } } diff --git a/postgres-derive/src/fromsql.rs b/postgres-derive/src/fromsql.rs index bb87ded5f..d3ac47f4f 100644 --- a/postgres-derive/src/fromsql.rs +++ b/postgres-derive/src/fromsql.rs @@ -15,16 +15,19 @@ use crate::enums::Variant; use crate::overrides::Overrides; pub fn expand_derive_fromsql(input: DeriveInput) -> Result { - let overrides = Overrides::extract(&input.attrs)?; + let overrides = Overrides::extract(&input.attrs, true)?; - if overrides.name.is_some() && overrides.transparent { + if (overrides.name.is_some() || overrides.rename_all.is_some()) && overrides.transparent { return Err(Error::new_spanned( &input, - "#[postgres(transparent)] is not allowed with #[postgres(name = \"...\")]", + "#[postgres(transparent)] is not allowed with #[postgres(name = \"...\")] or #[postgres(rename_all = \"...\")]", )); } - let name = overrides.name.unwrap_or_else(|| input.ident.to_string()); + let name = overrides + .name + .clone() + .unwrap_or_else(|| input.ident.to_string()); let (accepts_body, to_sql_body) = if overrides.transparent { match input.data { @@ -45,16 +48,36 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result { )) } } + } else if overrides.allow_mismatch { + match input.data { + Data::Enum(ref data) => { + let variants = data + .variants + .iter() + .map(|variant| Variant::parse(variant, overrides.rename_all)) + .collect::, _>>()?; + ( + accepts::enum_body(&name, &variants, overrides.allow_mismatch), + enum_body(&input.ident, &variants), + ) + } + _ => { + return Err(Error::new_spanned( + input, + "#[postgres(allow_mismatch)] may only be applied to enums", + )); + } + } } else { match input.data { Data::Enum(ref data) => { let variants = data .variants .iter() - .map(Variant::parse) + .map(|variant| Variant::parse(variant, overrides.rename_all)) .collect::, _>>()?; ( - accepts::enum_body(&name, &variants), + accepts::enum_body(&name, &variants, overrides.allow_mismatch), enum_body(&input.ident, &variants), ) } @@ -75,7 +98,7 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result { let fields = fields .named .iter() - .map(Field::parse) + .map(|field| Field::parse(field, overrides.rename_all)) .collect::, _>>()?; ( accepts::composite_body(&name, "FromSql", &fields), diff --git a/postgres-derive/src/lib.rs b/postgres-derive/src/lib.rs index 98e6add24..b849096c9 100644 --- a/postgres-derive/src/lib.rs +++ b/postgres-derive/src/lib.rs @@ -7,6 +7,7 @@ use proc_macro::TokenStream; use syn::parse_macro_input; mod accepts; +mod case; mod composites; mod enums; mod fromsql; diff --git a/postgres-derive/src/overrides.rs b/postgres-derive/src/overrides.rs index ddb37688b..d50550bee 100644 --- a/postgres-derive/src/overrides.rs +++ b/postgres-derive/src/overrides.rs @@ -1,16 +1,22 @@ use syn::punctuated::Punctuated; use syn::{Attribute, Error, Expr, ExprLit, Lit, Meta, Token}; +use crate::case::{RenameRule, RENAME_RULES}; + pub struct Overrides { pub name: Option, + pub rename_all: Option, pub transparent: bool, + pub allow_mismatch: bool, } impl Overrides { - pub fn extract(attrs: &[Attribute]) -> Result { + pub fn extract(attrs: &[Attribute], container_attr: bool) -> Result { let mut overrides = Overrides { name: None, + rename_all: None, transparent: false, + allow_mismatch: false, }; for attr in attrs { @@ -28,7 +34,15 @@ impl Overrides { for item in nested { match item { Meta::NameValue(meta) => { - if !meta.path.is_ident("name") { + let name_override = meta.path.is_ident("name"); + let rename_all_override = meta.path.is_ident("rename_all"); + if !container_attr && rename_all_override { + return Err(Error::new_spanned( + &meta.path, + "rename_all is a container attribute", + )); + } + if !name_override && !rename_all_override { return Err(Error::new_spanned(&meta.path, "unknown override")); } @@ -41,14 +55,46 @@ impl Overrides { } }; - overrides.name = Some(value); + if name_override { + overrides.name = Some(value); + } else if rename_all_override { + let rename_rule = RenameRule::from_str(&value).ok_or_else(|| { + Error::new_spanned( + &meta.value, + format!( + "invalid rename_all rule, expected one of: {}", + RENAME_RULES + .iter() + .map(|rule| format!("\"{}\"", rule)) + .collect::>() + .join(", ") + ), + ) + })?; + + overrides.rename_all = Some(rename_rule); + } } Meta::Path(path) => { - if !path.is_ident("transparent") { + if path.is_ident("transparent") { + if overrides.allow_mismatch { + return Err(Error::new_spanned( + path, + "#[postgres(allow_mismatch)] is not allowed with #[postgres(transparent)]", + )); + } + overrides.transparent = true; + } else if path.is_ident("allow_mismatch") { + if overrides.transparent { + return Err(Error::new_spanned( + path, + "#[postgres(transparent)] is not allowed with #[postgres(allow_mismatch)]", + )); + } + overrides.allow_mismatch = true; + } else { return Err(Error::new_spanned(path, "unknown override")); } - - overrides.transparent = true; } bad => return Err(Error::new_spanned(bad, "unknown attribute")), } diff --git a/postgres-derive/src/tosql.rs b/postgres-derive/src/tosql.rs index e51acc7fd..81d4834bf 100644 --- a/postgres-derive/src/tosql.rs +++ b/postgres-derive/src/tosql.rs @@ -13,16 +13,19 @@ use crate::enums::Variant; use crate::overrides::Overrides; pub fn expand_derive_tosql(input: DeriveInput) -> Result { - let overrides = Overrides::extract(&input.attrs)?; + let overrides = Overrides::extract(&input.attrs, true)?; - if overrides.name.is_some() && overrides.transparent { + if (overrides.name.is_some() || overrides.rename_all.is_some()) && overrides.transparent { return Err(Error::new_spanned( &input, - "#[postgres(transparent)] is not allowed with #[postgres(name = \"...\")]", + "#[postgres(transparent)] is not allowed with #[postgres(name = \"...\")] or #[postgres(rename_all = \"...\")]", )); } - let name = overrides.name.unwrap_or_else(|| input.ident.to_string()); + let name = overrides + .name + .clone() + .unwrap_or_else(|| input.ident.to_string()); let (accepts_body, to_sql_body) = if overrides.transparent { match input.data { @@ -41,16 +44,36 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result { )); } } + } else if overrides.allow_mismatch { + match input.data { + Data::Enum(ref data) => { + let variants = data + .variants + .iter() + .map(|variant| Variant::parse(variant, overrides.rename_all)) + .collect::, _>>()?; + ( + accepts::enum_body(&name, &variants, overrides.allow_mismatch), + enum_body(&input.ident, &variants), + ) + } + _ => { + return Err(Error::new_spanned( + input, + "#[postgres(allow_mismatch)] may only be applied to enums", + )); + } + } } else { match input.data { Data::Enum(ref data) => { let variants = data .variants .iter() - .map(Variant::parse) + .map(|variant| Variant::parse(variant, overrides.rename_all)) .collect::, _>>()?; ( - accepts::enum_body(&name, &variants), + accepts::enum_body(&name, &variants, overrides.allow_mismatch), enum_body(&input.ident, &variants), ) } @@ -69,7 +92,7 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result { let fields = fields .named .iter() - .map(Field::parse) + .map(|field| Field::parse(field, overrides.rename_all)) .collect::, _>>()?; ( accepts::composite_body(&name, "ToSql", &fields), diff --git a/postgres-protocol/CHANGELOG.md b/postgres-protocol/CHANGELOG.md index 034fd637c..1c371675c 100644 --- a/postgres-protocol/CHANGELOG.md +++ b/postgres-protocol/CHANGELOG.md @@ -1,5 +1,11 @@ # Change Log +## v0.6.6 -2023-08-19 + +### Added + +* Added the `js` feature for WASM support. + ## v0.6.5 - 2023-03-27 ### Added diff --git a/postgres-protocol/Cargo.toml b/postgres-protocol/Cargo.toml index e32211369..b44994811 100644 --- a/postgres-protocol/Cargo.toml +++ b/postgres-protocol/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "postgres-protocol" -version = "0.6.5" +version = "0.6.6" authors = ["Steven Fackler "] edition = "2018" description = "Low level Postgres protocol APIs" @@ -8,6 +8,10 @@ license = "MIT/Apache-2.0" repository = "https://github.com/sfackler/rust-postgres" readme = "../README.md" +[features] +default = [] +js = ["getrandom/js"] + [dependencies] base64 = "0.21" byteorder = "1.0" @@ -19,3 +23,4 @@ memchr = "2.0" rand = "0.8" sha2 = "0.10" stringprep = "0.1" +getrandom = { version = "0.2", optional = true } diff --git a/postgres-protocol/src/lib.rs b/postgres-protocol/src/lib.rs index 8b6ff508d..83d9bf55c 100644 --- a/postgres-protocol/src/lib.rs +++ b/postgres-protocol/src/lib.rs @@ -9,7 +9,6 @@ //! //! This library assumes that the `client_encoding` backend parameter has been //! set to `UTF8`. It will most likely not behave properly if that is not the case. -#![doc(html_root_url = "https://docs.rs/postgres-protocol/0.6")] #![warn(missing_docs, rust_2018_idioms, clippy::all)] use byteorder::{BigEndian, ByteOrder}; diff --git a/postgres-protocol/src/types/test.rs b/postgres-protocol/src/types/test.rs index 6f1851fc2..3e33b08f0 100644 --- a/postgres-protocol/src/types/test.rs +++ b/postgres-protocol/src/types/test.rs @@ -174,7 +174,7 @@ fn ltree_str() { let mut query = vec![1u8]; query.extend_from_slice("A.B.C".as_bytes()); - assert!(matches!(ltree_from_sql(query.as_slice()), Ok(_))) + assert!(ltree_from_sql(query.as_slice()).is_ok()) } #[test] @@ -182,7 +182,7 @@ fn ltree_wrong_version() { let mut query = vec![2u8]; query.extend_from_slice("A.B.C".as_bytes()); - assert!(matches!(ltree_from_sql(query.as_slice()), Err(_))) + assert!(ltree_from_sql(query.as_slice()).is_err()) } #[test] @@ -202,7 +202,7 @@ fn lquery_str() { let mut query = vec![1u8]; query.extend_from_slice("A.B.C".as_bytes()); - assert!(matches!(lquery_from_sql(query.as_slice()), Ok(_))) + assert!(lquery_from_sql(query.as_slice()).is_ok()) } #[test] @@ -210,7 +210,7 @@ fn lquery_wrong_version() { let mut query = vec![2u8]; query.extend_from_slice("A.B.C".as_bytes()); - assert!(matches!(lquery_from_sql(query.as_slice()), Err(_))) + assert!(lquery_from_sql(query.as_slice()).is_err()) } #[test] @@ -230,7 +230,7 @@ fn ltxtquery_str() { let mut query = vec![1u8]; query.extend_from_slice("a & b*".as_bytes()); - assert!(matches!(ltree_from_sql(query.as_slice()), Ok(_))) + assert!(ltree_from_sql(query.as_slice()).is_ok()) } #[test] @@ -238,5 +238,5 @@ fn ltxtquery_wrong_version() { let mut query = vec![2u8]; query.extend_from_slice("a & b*".as_bytes()); - assert!(matches!(ltree_from_sql(query.as_slice()), Err(_))) + assert!(ltree_from_sql(query.as_slice()).is_err()) } diff --git a/postgres-types/CHANGELOG.md b/postgres-types/CHANGELOG.md index 0f42f3495..72a1cbb6a 100644 --- a/postgres-types/CHANGELOG.md +++ b/postgres-types/CHANGELOG.md @@ -1,14 +1,25 @@ # Change Log +## v0.2.6 - 2023-08-19 + +### Fixed + +* Fixed serialization to `OIDVECTOR` and `INT2VECTOR`. + +### Added + +* Removed the `'static` requirement for the `impl BorrowToSql for Box`. +* Added a `ToSql` implementation for `Cow<[u8]>`. + ## v0.2.5 - 2023-03-27 -## Added +### Added * Added support for multi-range types. ## v0.2.4 - 2022-08-20 -## Added +### Added * Added `ToSql` and `FromSql` implementations for `Box<[T]>`. * Added `ToSql` and `FromSql` implementations for `[u8; N]` via the `array-impls` feature. diff --git a/postgres-types/Cargo.toml b/postgres-types/Cargo.toml index 35cdd6e7b..193d159a1 100644 --- a/postgres-types/Cargo.toml +++ b/postgres-types/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "postgres-types" -version = "0.2.5" +version = "0.2.6" authors = ["Steven Fackler "] edition = "2018" license = "MIT/Apache-2.0" @@ -30,8 +30,8 @@ with-time-0_3 = ["time-03"] [dependencies] bytes = "1.0" fallible-iterator = "0.2" -postgres-protocol = { version = "0.6.4", path = "../postgres-protocol" } -postgres-derive = { version = "0.4.2", optional = true, path = "../postgres-derive" } +postgres-protocol = { version = "0.6.5", path = "../postgres-protocol" } +postgres-derive = { version = "0.4.5", optional = true, path = "../postgres-derive" } array-init = { version = "2", optional = true } bit-vec-06 = { version = "0.6", package = "bit-vec", optional = true } diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index fa49d99eb..52b5c773a 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -125,9 +125,56 @@ //! Happy, //! } //! ``` -#![doc(html_root_url = "https://docs.rs/postgres-types/0.2")] +//! +//! Alternatively, the `#[postgres(rename_all = "...")]` attribute can be used to rename all fields or variants +//! with the chosen casing convention. This will not affect the struct or enum's type name. Note that +//! `#[postgres(name = "...")]` takes precendence when used in conjunction with `#[postgres(rename_all = "...")]`: +//! +//! ```rust +//! # #[cfg(feature = "derive")] +//! use postgres_types::{ToSql, FromSql}; +//! +//! # #[cfg(feature = "derive")] +//! #[derive(Debug, ToSql, FromSql)] +//! #[postgres(name = "mood", rename_all = "snake_case")] +//! enum Mood { +//! #[postgres(name = "ok")] +//! Ok, // ok +//! VeryHappy, // very_happy +//! } +//! ``` +//! +//! The following case conventions are supported: +//! - `"lowercase"` +//! - `"UPPERCASE"` +//! - `"PascalCase"` +//! - `"camelCase"` +//! - `"snake_case"` +//! - `"SCREAMING_SNAKE_CASE"` +//! - `"kebab-case"` +//! - `"SCREAMING-KEBAB-CASE"` +//! - `"Train-Case"` +//! +//! ## Allowing Enum Mismatches +//! +//! By default the generated implementation of [`ToSql`] & [`FromSql`] for enums will require an exact match of the enum +//! variants between the Rust and Postgres types. +//! To allow mismatches, the `#[postgres(allow_mismatch)]` attribute can be used on the enum definition: +//! +//! ```sql +//! CREATE TYPE mood AS ENUM ( +//! 'Sad', +//! 'Ok', +//! 'Happy' +//! ); +//! ``` +//! #[postgres(allow_mismatch)] +//! enum Mood { +//! Happy, +//! Meh, +//! } +//! ``` #![warn(clippy::all, rust_2018_idioms, missing_docs)] - use fallible_iterator::FallibleIterator; use postgres_protocol::types::{self, ArrayDimension}; use std::any::type_name; @@ -910,9 +957,15 @@ impl<'a, T: ToSql> ToSql for &'a [T] { _ => panic!("expected array type"), }; + // Arrays are normally one indexed by default but oidvector and int2vector *require* zero indexing + let lower_bound = match *ty { + Type::OID_VECTOR | Type::INT2_VECTOR => 0, + _ => 1, + }; + let dimension = ArrayDimension { len: downcast(self.len())?, - lower_bound: 1, + lower_bound, }; types::array_to_sql( @@ -998,6 +1051,18 @@ impl ToSql for Box<[T]> { to_sql_checked!(); } +impl<'a> ToSql for Cow<'a, [u8]> { + fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { + <&[u8] as ToSql>::to_sql(&self.as_ref(), ty, w) + } + + fn accepts(ty: &Type) -> bool { + <&[u8] as ToSql>::accepts(ty) + } + + to_sql_checked!(); +} + impl ToSql for Vec { fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { <&[u8] as ToSql>::to_sql(&&**self, ty, w) @@ -1012,28 +1077,20 @@ impl ToSql for Vec { impl<'a> ToSql for &'a str { fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { - match *ty { - ref ty if ty.name() == "ltree" => types::ltree_to_sql(self, w), - ref ty if ty.name() == "lquery" => types::lquery_to_sql(self, w), - ref ty if ty.name() == "ltxtquery" => types::ltxtquery_to_sql(self, w), + match ty.name() { + "ltree" => types::ltree_to_sql(self, w), + "lquery" => types::lquery_to_sql(self, w), + "ltxtquery" => types::ltxtquery_to_sql(self, w), _ => types::text_to_sql(self, w), } Ok(IsNull::No) } fn accepts(ty: &Type) -> bool { - match *ty { - Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true, - ref ty - if (ty.name() == "citext" - || ty.name() == "ltree" - || ty.name() == "lquery" - || ty.name() == "ltxtquery") => - { - true - } - _ => false, - } + matches!( + *ty, + Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN + ) || matches!(ty.name(), "citext" | "ltree" | "lquery" | "ltxtquery") } to_sql_checked!(); @@ -1186,17 +1243,17 @@ impl BorrowToSql for &dyn ToSql { } } -impl sealed::Sealed for Box {} +impl<'a> sealed::Sealed for Box {} -impl BorrowToSql for Box { +impl<'a> BorrowToSql for Box { #[inline] fn borrow_to_sql(&self) -> &dyn ToSql { self.as_ref() } } -impl sealed::Sealed for Box {} -impl BorrowToSql for Box { +impl<'a> sealed::Sealed for Box {} +impl<'a> BorrowToSql for Box { #[inline] fn borrow_to_sql(&self) -> &dyn ToSql { self.as_ref() diff --git a/postgres/CHANGELOG.md b/postgres/CHANGELOG.md index b8263a04a..7f856b5ac 100644 --- a/postgres/CHANGELOG.md +++ b/postgres/CHANGELOG.md @@ -1,20 +1,34 @@ # Change Log +## v0.19.7 - 2023-08-25 + +## Fixed + +* Defered default username lookup to avoid regressing `Config` behavior. + +## v0.19.6 - 2023-08-19 + +### Added + +* Added support for the `hostaddr` config option to bypass DNS lookups. +* Added support for the `load_balance_hosts` config option to randomize connection ordering. +* The `user` config option now defaults to the executing process's user. + ## v0.19.5 - 2023-03-27 -## Added +### Added * Added `keepalives_interval` and `keepalives_retries` config options. * Added the `tcp_user_timeout` config option. * Added `RowIter::rows_affected`. -## Changed +### Changed * Passing an incorrect number of parameters to a query method now returns an error instead of panicking. ## v0.19.4 - 2022-08-21 -## Added +### Added * Added `ToSql` and `FromSql` implementations for `[u8; N]` via the `array-impls` feature. * Added support for `smol_str` 0.1 via the `with-smol_str-01` feature. diff --git a/postgres/Cargo.toml b/postgres/Cargo.toml index e0b2a249d..18406da9f 100644 --- a/postgres/Cargo.toml +++ b/postgres/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "postgres" -version = "0.19.5" +version = "0.19.7" authors = ["Steven Fackler "] edition = "2018" license = "MIT/Apache-2.0" @@ -39,11 +39,9 @@ with-time-0_3 = ["tokio-postgres/with-time-0_3"] bytes = "1.0" fallible-iterator = "0.2" futures-util = { version = "0.3.14", features = ["sink"] } -tokio-postgres = { version = "0.7.8", path = "../tokio-postgres" } - -tokio = { version = "1.0", features = ["rt", "time"] } log = "0.4" +tokio-postgres = { version = "0.7.10", path = "../tokio-postgres" } +tokio = { version = "1.0", features = ["rt", "time"] } [dev-dependencies] -criterion = "0.4" -tokio = { version = "1.0", features = ["rt-multi-thread"] } +criterion = "0.5" diff --git a/postgres/src/config.rs b/postgres/src/config.rs index 95c5ea417..a32ddc78e 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -1,18 +1,19 @@ //! Connection configuration. -//! -//! Requires the `runtime` Cargo feature (enabled by default). use crate::connection::Connection; use crate::Client; use log::info; use std::fmt; +use std::net::IpAddr; use std::path::Path; use std::str::FromStr; use std::sync::Arc; use std::time::Duration; use tokio::runtime; #[doc(inline)] -pub use tokio_postgres::config::{ChannelBinding, Host, SslMode, TargetSessionAttrs}; +pub use tokio_postgres::config::{ + ChannelBinding, Host, LoadBalanceHosts, SslMode, TargetSessionAttrs, +}; use tokio_postgres::error::DbError; use tokio_postgres::tls::{MakeTlsConnect, TlsConnect}; use tokio_postgres::{Error, Socket}; @@ -28,7 +29,7 @@ use tokio_postgres::{Error, Socket}; /// /// ## Keys /// -/// * `user` - The username to authenticate with. Required. +/// * `user` - The username to authenticate with. Defaults to the user executing this process. /// * `password` - The password to authenticate with. /// * `dbname` - The name of the database to connect to. Defaults to the username. /// * `options` - Command line options used to configure the server. @@ -39,6 +40,19 @@ use tokio_postgres::{Error, Socket}; /// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts /// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting /// with the `connect` method. +/// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format, +/// e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses. +/// If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address, +/// or if host specifies an IP address, that value will be used directly. +/// Using `hostaddr` allows the application to avoid a host name look-up, which might be important in applications +/// with time constraints. However, a host name is required for TLS certificate verification. +/// Specifically: +/// * If `hostaddr` is specified without `host`, the value for `hostaddr` gives the server network address. +/// The connection attempt will fail if the authentication method requires a host name; +/// * If `host` is specified without `hostaddr`, a host name lookup occurs; +/// * If both `host` and `hostaddr` are specified, the value for `hostaddr` gives the server network address. +/// The value for `host` is ignored unless the authentication method requires it, +/// in which case it will be used as the host name. /// * `port` - The port to connect to. Multiple ports can be specified, separated by commas. The number of ports must be /// either 1, in which case it will be used for all hosts, or the same as the number of hosts. Defaults to 5432 if /// omitted or the empty string. @@ -58,6 +72,15 @@ use tokio_postgres::{Error, Socket}; /// * `target_session_attrs` - Specifies requirements of the session. If set to `read-write`, the client will check that /// the `transaction_read_write` session parameter is set to `on`. This can be used to connect to the primary server /// in a database cluster as opposed to the secondary read-only mirrors. Defaults to `all`. +/// * `channel_binding` - Controls usage of channel binding in the authentication process. If set to `disable`, channel +/// binding will not be used. If set to `prefer`, channel binding will be used if available, but not used otherwise. +/// If set to `require`, the authentication process will fail if channel binding is not used. Defaults to `prefer`. +/// * `load_balance_hosts` - Controls the order in which the client tries to connect to the available hosts and +/// addresses. Once a connection attempt is successful no other hosts and addresses will be tried. This parameter +/// is typically used in combination with multiple host names or a DNS record that returns multiple IPs. If set to +/// `disable`, hosts and addresses will be tried in the order provided. If set to `random`, hosts will be tried +/// in a random order, and the IP addresses resolved from a hostname will also be tried in a random order. Defaults +/// to `disable`. /// /// ## Examples /// @@ -66,7 +89,11 @@ use tokio_postgres::{Error, Socket}; /// ``` /// /// ```not_rust -/// host=/var/run/postgresql,localhost port=1234 user=postgres password='password with spaces' +/// host=/var/lib/postgresql,localhost port=1234 user=postgres password='password with spaces' +/// ``` +/// +/// ```not_rust +/// host=host1,host2,host3 port=1234,,5678 hostaddr=127.0.0.1,127.0.0.2,127.0.0.3 user=postgres target_session_attrs=read-write /// ``` /// /// ```not_rust @@ -76,7 +103,7 @@ use tokio_postgres::{Error, Socket}; /// # Url /// /// This format resembles a URL with a scheme of either `postgres://` or `postgresql://`. All components are optional, -/// and the format accept query parameters for all of the key-value pairs described in the section above. Multiple +/// and the format accepts query parameters for all of the key-value pairs described in the section above. Multiple /// host/port pairs can be comma-separated. Unix socket paths in the host section of the URL should be percent-encoded, /// as the path component of the URL specifies the database name. /// @@ -87,7 +114,7 @@ use tokio_postgres::{Error, Socket}; /// ``` /// /// ```not_rust -/// postgresql://user:password@%2Fvar%2Frun%2Fpostgresql/mydb?connect_timeout=10 +/// postgresql://user:password@%2Fvar%2Flib%2Fpostgresql/mydb?connect_timeout=10 /// ``` /// /// ```not_rust @@ -95,7 +122,7 @@ use tokio_postgres::{Error, Socket}; /// ``` /// /// ```not_rust -/// postgresql:///mydb?user=user&host=/var/run/postgresql +/// postgresql:///mydb?user=user&host=/var/lib/postgresql /// ``` #[derive(Clone)] pub struct Config { @@ -125,7 +152,7 @@ impl Config { /// Sets the user to authenticate with. /// - /// Required. + /// If the user is not set, then this defaults to the user executing this process. pub fn user(&mut self, user: &str) -> &mut Config { self.config.user(user); self @@ -207,6 +234,7 @@ impl Config { /// /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix /// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets. + /// There must be either no hosts, or the same number of hosts as hostaddrs. pub fn host(&mut self, host: &str) -> &mut Config { self.config.host(host); self @@ -217,6 +245,11 @@ impl Config { self.config.get_hosts() } + /// Gets the hostaddrs that have been added to the configuration with `hostaddr`. + pub fn get_hostaddrs(&self) -> &[IpAddr] { + self.config.get_hostaddrs() + } + /// Adds a Unix socket host to the configuration. /// /// Unlike `host`, this method allows non-UTF8 paths. @@ -229,6 +262,15 @@ impl Config { self } + /// Adds a hostaddr to the configuration. + /// + /// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order. + /// There must be either no hostaddrs, or the same number of hostaddrs as hosts. + pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config { + self.config.hostaddr(hostaddr); + self + } + /// Adds a port to the configuration. /// /// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which @@ -359,6 +401,19 @@ impl Config { self.config.get_channel_binding() } + /// Sets the host load balancing behavior. + /// + /// Defaults to `disable`. + pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config { + self.config.load_balance_hosts(load_balance_hosts); + self + } + + /// Gets the host load balancing behavior. + pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts { + self.config.get_load_balance_hosts() + } + /// Sets the notice callback. /// /// This callback will be invoked with the contents of every diff --git a/tokio-postgres/CHANGELOG.md b/tokio-postgres/CHANGELOG.md index 3345a1d43..75448d130 100644 --- a/tokio-postgres/CHANGELOG.md +++ b/tokio-postgres/CHANGELOG.md @@ -1,6 +1,25 @@ # Change Log -## v0.7.8 +## v0.7.10 - 2023-08-25 + +## Fixed + +* Defered default username lookup to avoid regressing `Config` behavior. + +## v0.7.9 - 2023-08-19 + +## Fixed + +* Fixed builds on OpenBSD. + +## Added + +* Added the `js` feature for WASM support. +* Added support for the `hostaddr` config option to bypass DNS lookups. +* Added support for the `load_balance_hosts` config option to randomize connection ordering. +* The `user` config option now defaults to the executing process's user. + +## v0.7.8 - 2023-05-27 ## Added diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index e6c7b4f7a..fefaddd0b 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-postgres" -version = "0.7.8" +version = "0.7.10" authors = ["Steven Fackler "] edition = "2018" license = "MIT/Apache-2.0" @@ -40,6 +40,7 @@ with-uuid-0_8 = ["postgres-types/with-uuid-0_8"] with-uuid-1 = ["postgres-types/with-uuid-1"] with-time-0_2 = ["postgres-types/with-time-0_2"] with-time-0_3 = ["postgres-types/with-time-0_3"] +js = ["postgres-protocol/js"] [dependencies] async-trait = "0.1" @@ -53,15 +54,19 @@ parking_lot = "0.12" percent-encoding = "2.0" pin-project-lite = "0.2" phf = "0.11" -postgres-protocol = { version = "0.6.4" } -postgres-types = { version = "0.2.4" } -socket2 = { version = "0.5", features = ["all"] } +postgres-protocol = { version = "0.6.6" } +postgres-types = { version = "0.2.5" } tokio = { version = "0.2.2", package = "madsim-tokio", features = ["io-util"] } tokio-util = { version = "0.7", features = ["codec"] } +rand = "0.8.5" +whoami = "1.4.1" + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +socket2 = { version = "0.5", features = ["all"] } [dev-dependencies] futures-executor = "0.3" -criterion = "0.4" +criterion = "0.5" env_logger = "0.10" tokio = { version = "0.2", package = "madsim-tokio", features = [ "macros", diff --git a/tokio-postgres/src/cancel_query.rs b/tokio-postgres/src/cancel_query.rs index d869b5824..078d4b8b6 100644 --- a/tokio-postgres/src/cancel_query.rs +++ b/tokio-postgres/src/cancel_query.rs @@ -1,5 +1,5 @@ use crate::client::SocketConfig; -use crate::config::{Host, SslMode}; +use crate::config::SslMode; use crate::tls::MakeTlsConnect; use crate::{cancel_query_raw, connect_socket, Error, Socket}; use std::io; @@ -24,18 +24,13 @@ where } }; - let hostname = match &config.host { - Host::Tcp(host) => &**host, - // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter - #[cfg(unix)] - Host::Unix(_) => "", - }; let tls = tls - .make_tls_connect(hostname) + .make_tls_connect(config.hostname.as_deref().unwrap_or("")) .map_err(|e| Error::tls(e.into()))?; + let has_hostname = config.hostname.is_some(); let socket = connect_socket::connect_socket( - &config.host, + &config.addr, config.port, config.connect_timeout, config.tcp_user_timeout, @@ -43,5 +38,6 @@ where ) .await?; - cancel_query_raw::cancel_query_raw(socket, ssl_mode, tls, process_id, secret_key).await + cancel_query_raw::cancel_query_raw(socket, ssl_mode, tls, has_hostname, process_id, secret_key) + .await } diff --git a/tokio-postgres/src/cancel_query_raw.rs b/tokio-postgres/src/cancel_query_raw.rs index c89dc581f..41aafe7d9 100644 --- a/tokio-postgres/src/cancel_query_raw.rs +++ b/tokio-postgres/src/cancel_query_raw.rs @@ -9,6 +9,7 @@ pub async fn cancel_query_raw( stream: S, mode: SslMode, tls: T, + has_hostname: bool, process_id: i32, secret_key: i32, ) -> Result<(), Error> @@ -16,7 +17,7 @@ where S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, { - let mut stream = connect_tls::connect_tls(stream, mode, tls).await?; + let mut stream = connect_tls::connect_tls(stream, mode, tls, has_hostname).await?; let mut buf = BytesMut::new(); frontend::cancel_request(process_id, secret_key, &mut buf); diff --git a/tokio-postgres/src/cancel_token.rs b/tokio-postgres/src/cancel_token.rs index d048a3c82..c925ce0ca 100644 --- a/tokio-postgres/src/cancel_token.rs +++ b/tokio-postgres/src/cancel_token.rs @@ -55,6 +55,7 @@ impl CancelToken { stream, self.ssl_mode, tls, + true, self.process_id, self.secret_key, ) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 8b7df4e87..427a05049 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -1,6 +1,4 @@ use crate::codec::{BackendMessages, FrontendMessage}; -#[cfg(feature = "runtime")] -use crate::config::Host; use crate::config::SslMode; use crate::connection::{Request, RequestMessages}; use crate::copy_out::CopyOutStream; @@ -27,6 +25,10 @@ use postgres_protocol::message::{backend::Message, frontend}; use postgres_types::BorrowToSql; use std::collections::HashMap; use std::fmt; +#[cfg(feature = "runtime")] +use std::net::IpAddr; +#[cfg(feature = "runtime")] +use std::path::PathBuf; use std::sync::Arc; use std::task::{Context, Poll}; #[cfg(feature = "runtime")] @@ -153,13 +155,22 @@ impl InnerClient { #[cfg(feature = "runtime")] #[derive(Clone)] pub(crate) struct SocketConfig { - pub host: Host, + pub addr: Addr, + pub hostname: Option, pub port: u16, pub connect_timeout: Option, pub tcp_user_timeout: Option, pub keepalive: Option, } +#[cfg(feature = "runtime")] +#[derive(Clone)] +pub(crate) enum Addr { + Tcp(IpAddr), + #[cfg(unix)] + Unix(PathBuf), +} + /// An asynchronous PostgreSQL client. /// /// The client is one half of what is returned when a connection is established. Users interact with the database diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index a8aa7a9f5..b178eac80 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -3,6 +3,7 @@ #[cfg(feature = "runtime")] use crate::connect::connect; use crate::connect_raw::connect_raw; +#[cfg(not(target_arch = "wasm32"))] use crate::keepalive::KeepaliveConfig; #[cfg(feature = "runtime")] use crate::tls::MakeTlsConnect; @@ -13,6 +14,8 @@ use crate::{Client, Connection, Error}; use std::borrow::Cow; #[cfg(unix)] use std::ffi::OsStr; +use std::net::IpAddr; +use std::ops::Deref; #[cfg(unix)] use std::os::unix::ffi::OsStrExt; #[cfg(unix)] @@ -57,6 +60,16 @@ pub enum ChannelBinding { Require, } +/// Load balancing configuration. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum LoadBalanceHosts { + /// Make connection attempts to hosts in the order provided. + Disable, + /// Make connection attempts to hosts in a random order. + Random, +} + /// A host specification. #[derive(Debug, Clone, PartialEq, Eq)] pub enum Host { @@ -80,7 +93,7 @@ pub enum Host { /// /// ## Keys /// -/// * `user` - The username to authenticate with. Required. +/// * `user` - The username to authenticate with. Defaults to the user executing this process. /// * `password` - The password to authenticate with. /// * `dbname` - The name of the database to connect to. Defaults to the username. /// * `options` - Command line options used to configure the server. @@ -91,6 +104,19 @@ pub enum Host { /// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts /// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting /// with the `connect` method. +/// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format, +/// e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses. +/// If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address, +/// or if host specifies an IP address, that value will be used directly. +/// Using `hostaddr` allows the application to avoid a host name look-up, which might be important in applications +/// with time constraints. However, a host name is required for TLS certificate verification. +/// Specifically: +/// * If `hostaddr` is specified without `host`, the value for `hostaddr` gives the server network address. +/// The connection attempt will fail if the authentication method requires a host name; +/// * If `host` is specified without `hostaddr`, a host name lookup occurs; +/// * If both `host` and `hostaddr` are specified, the value for `hostaddr` gives the server network address. +/// The value for `host` is ignored unless the authentication method requires it, +/// in which case it will be used as the host name. /// * `port` - The port to connect to. Multiple ports can be specified, separated by commas. The number of ports must be /// either 1, in which case it will be used for all hosts, or the same as the number of hosts. Defaults to 5432 if /// omitted or the empty string. @@ -113,6 +139,12 @@ pub enum Host { /// * `channel_binding` - Controls usage of channel binding in the authentication process. If set to `disable`, channel /// binding will not be used. If set to `prefer`, channel binding will be used if available, but not used otherwise. /// If set to `require`, the authentication process will fail if channel binding is not used. Defaults to `prefer`. +/// * `load_balance_hosts` - Controls the order in which the client tries to connect to the available hosts and +/// addresses. Once a connection attempt is successful no other hosts and addresses will be tried. This parameter +/// is typically used in combination with multiple host names or a DNS record that returns multiple IPs. If set to +/// `disable`, hosts and addresses will be tried in the order provided. If set to `random`, hosts will be tried +/// in a random order, and the IP addresses resolved from a hostname will also be tried in a random order. Defaults +/// to `disable`. /// /// ## Examples /// @@ -125,6 +157,10 @@ pub enum Host { /// ``` /// /// ```not_rust +/// host=host1,host2,host3 port=1234,,5678 hostaddr=127.0.0.1,127.0.0.2,127.0.0.3 user=postgres target_session_attrs=read-write +/// ``` +/// +/// ```not_rust /// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write /// ``` /// @@ -161,13 +197,16 @@ pub struct Config { pub(crate) application_name: Option, pub(crate) ssl_mode: SslMode, pub(crate) host: Vec, + pub(crate) hostaddr: Vec, pub(crate) port: Vec, pub(crate) connect_timeout: Option, pub(crate) tcp_user_timeout: Option, pub(crate) keepalives: bool, + #[cfg(not(target_arch = "wasm32"))] pub(crate) keepalive_config: KeepaliveConfig, pub(crate) target_session_attrs: TargetSessionAttrs, pub(crate) channel_binding: ChannelBinding, + pub(crate) load_balance_hosts: LoadBalanceHosts, } impl Default for Config { @@ -179,11 +218,6 @@ impl Default for Config { impl Config { /// Creates a new configuration. pub fn new() -> Config { - let keepalive_config = KeepaliveConfig { - idle: Duration::from_secs(2 * 60 * 60), - interval: None, - retries: None, - }; Config { user: None, password: None, @@ -192,19 +226,26 @@ impl Config { application_name: None, ssl_mode: SslMode::Prefer, host: vec![], + hostaddr: vec![], port: vec![], connect_timeout: None, tcp_user_timeout: None, keepalives: true, - keepalive_config, + #[cfg(not(target_arch = "wasm32"))] + keepalive_config: KeepaliveConfig { + idle: Duration::from_secs(2 * 60 * 60), + interval: None, + retries: None, + }, target_session_attrs: TargetSessionAttrs::Any, channel_binding: ChannelBinding::Prefer, + load_balance_hosts: LoadBalanceHosts::Disable, } } /// Sets the user to authenticate with. /// - /// Required. + /// Defaults to the user executing this process. pub fn user(&mut self, user: &str) -> &mut Config { self.user = Some(user.to_string()); self @@ -286,6 +327,7 @@ impl Config { /// /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix /// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets. + /// There must be either no hosts, or the same number of hosts as hostaddrs. pub fn host(&mut self, host: &str) -> &mut Config { #[cfg(unix)] { @@ -303,6 +345,11 @@ impl Config { &self.host } + /// Gets the hostaddrs that have been added to the configuration with `hostaddr`. + pub fn get_hostaddrs(&self) -> &[IpAddr] { + self.hostaddr.deref() + } + /// Adds a Unix socket host to the configuration. /// /// Unlike `host`, this method allows non-UTF8 paths. @@ -315,6 +362,15 @@ impl Config { self } + /// Adds a hostaddr to the configuration. + /// + /// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order. + /// There must be either no hostaddrs, or the same number of hostaddrs as hosts. + pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config { + self.hostaddr.push(hostaddr); + self + } + /// Adds a port to the configuration. /// /// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which @@ -377,6 +433,7 @@ impl Config { /// Sets the amount of idle time before a keepalive packet is sent on the connection. /// /// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. Defaults to 2 hours. + #[cfg(not(target_arch = "wasm32"))] pub fn keepalives_idle(&mut self, keepalives_idle: Duration) -> &mut Config { self.keepalive_config.idle = keepalives_idle; self @@ -384,6 +441,7 @@ impl Config { /// Gets the configured amount of idle time before a keepalive packet will /// be sent on the connection. + #[cfg(not(target_arch = "wasm32"))] pub fn get_keepalives_idle(&self) -> Duration { self.keepalive_config.idle } @@ -392,12 +450,14 @@ impl Config { /// On Windows, this sets the value of the tcp_keepalive struct’s keepaliveinterval field. /// /// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. + #[cfg(not(target_arch = "wasm32"))] pub fn keepalives_interval(&mut self, keepalives_interval: Duration) -> &mut Config { self.keepalive_config.interval = Some(keepalives_interval); self } /// Gets the time interval between TCP keepalive probes. + #[cfg(not(target_arch = "wasm32"))] pub fn get_keepalives_interval(&self) -> Option { self.keepalive_config.interval } @@ -405,12 +465,14 @@ impl Config { /// Sets the maximum number of TCP keepalive probes that will be sent before dropping a connection. /// /// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. + #[cfg(not(target_arch = "wasm32"))] pub fn keepalives_retries(&mut self, keepalives_retries: u32) -> &mut Config { self.keepalive_config.retries = Some(keepalives_retries); self } /// Gets the maximum number of TCP keepalive probes that will be sent before dropping a connection. + #[cfg(not(target_arch = "wasm32"))] pub fn get_keepalives_retries(&self) -> Option { self.keepalive_config.retries } @@ -445,6 +507,19 @@ impl Config { self.channel_binding } + /// Sets the host load balancing behavior. + /// + /// Defaults to `disable`. + pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config { + self.load_balance_hosts = load_balance_hosts; + self + } + + /// Gets the host load balancing behavior. + pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts { + self.load_balance_hosts + } + fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { match key { "user" => { @@ -476,6 +551,14 @@ impl Config { self.host(host); } } + "hostaddr" => { + for hostaddr in value.split(',') { + let addr = hostaddr + .parse() + .map_err(|_| Error::config_parse(Box::new(InvalidValue("hostaddr"))))?; + self.hostaddr(addr); + } + } "port" => { for port in value.split(',') { let port = if port.is_empty() { @@ -503,12 +586,14 @@ impl Config { self.tcp_user_timeout(Duration::from_secs(timeout as u64)); } } + #[cfg(not(target_arch = "wasm32"))] "keepalives" => { let keepalives = value .parse::() .map_err(|_| Error::config_parse(Box::new(InvalidValue("keepalives"))))?; self.keepalives(keepalives != 0); } + #[cfg(not(target_arch = "wasm32"))] "keepalives_idle" => { let keepalives_idle = value .parse::() @@ -517,6 +602,7 @@ impl Config { self.keepalives_idle(Duration::from_secs(keepalives_idle as u64)); } } + #[cfg(not(target_arch = "wasm32"))] "keepalives_interval" => { let keepalives_interval = value.parse::().map_err(|_| { Error::config_parse(Box::new(InvalidValue("keepalives_interval"))) @@ -525,6 +611,7 @@ impl Config { self.keepalives_interval(Duration::from_secs(keepalives_interval as u64)); } } + #[cfg(not(target_arch = "wasm32"))] "keepalives_retries" => { let keepalives_retries = value.parse::().map_err(|_| { Error::config_parse(Box::new(InvalidValue("keepalives_retries"))) @@ -556,6 +643,18 @@ impl Config { }; self.channel_binding(channel_binding); } + "load_balance_hosts" => { + let load_balance_hosts = match value { + "disable" => LoadBalanceHosts::Disable, + "random" => LoadBalanceHosts::Random, + _ => { + return Err(Error::config_parse(Box::new(InvalidValue( + "load_balance_hosts", + )))) + } + }; + self.load_balance_hosts(load_balance_hosts); + } key => { return Err(Error::config_parse(Box::new(UnknownOption( key.to_string(), @@ -589,7 +688,7 @@ impl Config { S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, { - connect_raw(stream, tls, self).await + connect_raw(stream, tls, true, self).await } } @@ -614,7 +713,8 @@ impl fmt::Debug for Config { } } - f.debug_struct("Config") + let mut config_dbg = &mut f.debug_struct("Config"); + config_dbg = config_dbg .field("user", &self.user) .field("password", &self.password.as_ref().map(|_| Redaction {})) .field("dbname", &self.dbname) @@ -622,13 +722,21 @@ impl fmt::Debug for Config { .field("application_name", &self.application_name) .field("ssl_mode", &self.ssl_mode) .field("host", &self.host) + .field("hostaddr", &self.hostaddr) .field("port", &self.port) .field("connect_timeout", &self.connect_timeout) .field("tcp_user_timeout", &self.tcp_user_timeout) - .field("keepalives", &self.keepalives) - .field("keepalives_idle", &self.keepalive_config.idle) - .field("keepalives_interval", &self.keepalive_config.interval) - .field("keepalives_retries", &self.keepalive_config.retries) + .field("keepalives", &self.keepalives); + + #[cfg(not(target_arch = "wasm32"))] + { + config_dbg = config_dbg + .field("keepalives_idle", &self.keepalive_config.idle) + .field("keepalives_interval", &self.keepalive_config.interval) + .field("keepalives_retries", &self.keepalive_config.retries); + } + + config_dbg .field("target_session_attrs", &self.target_session_attrs) .field("channel_binding", &self.channel_binding) .finish() @@ -1005,3 +1113,41 @@ impl<'a> UrlParser<'a> { .map_err(|e| Error::config_parse(e.into())) } } + +#[cfg(test)] +mod tests { + use std::net::IpAddr; + + use crate::{config::Host, Config}; + + #[test] + fn test_simple_parsing() { + let s = "user=pass_user dbname=postgres host=host1,host2 hostaddr=127.0.0.1,127.0.0.2 port=26257"; + let config = s.parse::().unwrap(); + assert_eq!(Some("pass_user"), config.get_user()); + assert_eq!(Some("postgres"), config.get_dbname()); + assert_eq!( + [ + Host::Tcp("host1".to_string()), + Host::Tcp("host2".to_string()) + ], + config.get_hosts(), + ); + + assert_eq!( + [ + "127.0.0.1".parse::().unwrap(), + "127.0.0.2".parse::().unwrap() + ], + config.get_hostaddrs(), + ); + + assert_eq!(1, 1); + } + + #[test] + fn test_invalid_hostaddr_parsing() { + let s = "user=pass_user dbname=postgres host=host1 hostaddr=127.0.0 port=26257"; + s.parse::().err().unwrap(); + } +} diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index ed7ecac66..ca57b9cdd 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -1,12 +1,14 @@ -use crate::client::SocketConfig; -use crate::config::{Host, TargetSessionAttrs}; +use crate::client::{Addr, SocketConfig}; +use crate::config::{Host, LoadBalanceHosts, TargetSessionAttrs}; use crate::connect_raw::connect_raw; use crate::connect_socket::connect_socket; -use crate::tls::{MakeTlsConnect, TlsConnect}; +use crate::tls::MakeTlsConnect; use crate::{Client, Config, Connection, Error, SimpleQueryMessage, Socket}; use futures_util::{future, pin_mut, Future, FutureExt, Stream}; -use std::io; +use rand::seq::SliceRandom; use std::task::Poll; +use std::{cmp, io}; +use tokio::net; pub async fn connect( mut tls: T, @@ -15,16 +17,40 @@ pub async fn connect( where T: MakeTlsConnect, { - if config.host.is_empty() { - return Err(Error::config("host missing".into())); + if config.host.is_empty() && config.hostaddr.is_empty() { + return Err(Error::config("both host and hostaddr are missing".into())); } - if config.port.len() > 1 && config.port.len() != config.host.len() { + if !config.host.is_empty() + && !config.hostaddr.is_empty() + && config.host.len() != config.hostaddr.len() + { + let msg = format!( + "number of hosts ({}) is different from number of hostaddrs ({})", + config.host.len(), + config.hostaddr.len(), + ); + return Err(Error::config(msg.into())); + } + + // At this point, either one of the following two scenarios could happen: + // (1) either config.host or config.hostaddr must be empty; + // (2) if both config.host and config.hostaddr are NOT empty; their lengths must be equal. + let num_hosts = cmp::max(config.host.len(), config.hostaddr.len()); + + if config.port.len() > 1 && config.port.len() != num_hosts { return Err(Error::config("invalid number of ports".into())); } + let mut indices = (0..num_hosts).collect::>(); + if config.load_balance_hosts == LoadBalanceHosts::Random { + indices.shuffle(&mut rand::thread_rng()); + } + let mut error = None; - for (i, host) in config.host.iter().enumerate() { + for i in indices { + let host = config.host.get(i); + let hostaddr = config.hostaddr.get(i); let port = config .port .get(i) @@ -32,18 +58,23 @@ where .copied() .unwrap_or(5432); + // The value of host is used as the hostname for TLS validation, let hostname = match host { - Host::Tcp(host) => host.as_str(), + Some(Host::Tcp(host)) => Some(host.clone()), // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter #[cfg(unix)] - Host::Unix(_) => "", + Some(Host::Unix(_)) => None, + None => None, }; - let tls = tls - .make_tls_connect(hostname) - .map_err(|e| Error::tls(e.into()))?; + // Try to use the value of hostaddr to establish the TCP connection, + // fallback to host if hostaddr is not present. + let addr = match hostaddr { + Some(ipaddr) => Host::Tcp(ipaddr.to_string()), + None => host.cloned().unwrap(), + }; - match connect_once(host, port, tls, config).await { + match connect_host(addr, hostname, port, &mut tls, config).await { Ok((client, connection)) => return Ok((client, connection)), Err(e) => error = Some(e), } @@ -52,17 +83,66 @@ where Err(error.unwrap()) } +async fn connect_host( + host: Host, + hostname: Option, + port: u16, + tls: &mut T, + config: &Config, +) -> Result<(Client, Connection), Error> +where + T: MakeTlsConnect, +{ + match host { + Host::Tcp(host) => { + let mut addrs = net::lookup_host((&*host, port)) + .await + .map_err(Error::connect)? + .collect::>(); + + if config.load_balance_hosts == LoadBalanceHosts::Random { + addrs.shuffle(&mut rand::thread_rng()); + } + + let mut last_err = None; + for addr in addrs { + match connect_once(Addr::Tcp(addr.ip()), hostname.as_deref(), port, tls, config) + .await + { + Ok(stream) => return Ok(stream), + Err(e) => { + last_err = Some(e); + continue; + } + }; + } + + Err(last_err.unwrap_or_else(|| { + Error::connect(io::Error::new( + io::ErrorKind::InvalidInput, + "could not resolve any addresses", + )) + })) + } + #[cfg(unix)] + Host::Unix(path) => { + connect_once(Addr::Unix(path), hostname.as_deref(), port, tls, config).await + } + } +} + async fn connect_once( - host: &Host, + addr: Addr, + hostname: Option<&str>, port: u16, - tls: T, + tls: &mut T, config: &Config, ) -> Result<(Client, Connection), Error> where - T: TlsConnect, + T: MakeTlsConnect, { let socket = connect_socket( - host, + &addr, port, config.connect_timeout, config.tcp_user_timeout, @@ -73,7 +153,12 @@ where }, ) .await?; - let (mut client, mut connection) = connect_raw(socket, tls, config).await?; + + let tls = tls + .make_tls_connect(hostname.unwrap_or("")) + .map_err(|e| Error::tls(e.into()))?; + let has_hostname = hostname.is_some(); + let (mut client, mut connection) = connect_raw(socket, tls, has_hostname, config).await?; if let TargetSessionAttrs::ReadWrite = config.target_session_attrs { let rows = client.simple_query_raw("SHOW transaction_read_only"); @@ -116,7 +201,8 @@ where } client.set_socket_config(SocketConfig { - host: host.clone(), + addr, + hostname: hostname.map(|s| s.to_string()), port, connect_timeout: config.connect_timeout, tcp_user_timeout: config.tcp_user_timeout, diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index d97636221..19be9eb01 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -13,6 +13,7 @@ use postgres_protocol::authentication::sasl; use postgres_protocol::authentication::sasl::ScramSha256; use postgres_protocol::message::backend::{AuthenticationSaslBody, Message}; use postgres_protocol::message::frontend; +use std::borrow::Cow; use std::collections::{HashMap, VecDeque}; use std::io; use std::pin::Pin; @@ -81,13 +82,14 @@ where pub async fn connect_raw( stream: S, tls: T, + has_hostname: bool, config: &Config, ) -> Result<(Client, Connection), Error> where S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, { - let stream = connect_tls(stream, config.ssl_mode, tls).await?; + let stream = connect_tls(stream, config.ssl_mode, tls, has_hostname).await?; let mut stream = StartupStream { inner: Framed::new(stream, PostgresCodec), @@ -95,8 +97,13 @@ where delayed: VecDeque::new(), }; - startup(&mut stream, config).await?; - authenticate(&mut stream, config).await?; + let user = config + .user + .as_deref() + .map_or_else(|| Cow::Owned(whoami::username()), Cow::Borrowed); + + startup(&mut stream, config, &user).await?; + authenticate(&mut stream, config, &user).await?; let (process_id, secret_key, parameters) = read_info(&mut stream).await?; let (sender, receiver) = mpsc::unbounded(); @@ -106,15 +113,17 @@ where Ok((client, connection)) } -async fn startup(stream: &mut StartupStream, config: &Config) -> Result<(), Error> +async fn startup( + stream: &mut StartupStream, + config: &Config, + user: &str, +) -> Result<(), Error> where S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin, { let mut params = vec![("client_encoding", "UTF8")]; - if let Some(user) = &config.user { - params.push(("user", &**user)); - } + params.push(("user", user)); if let Some(dbname) = &config.dbname { params.push(("database", &**dbname)); } @@ -134,7 +143,11 @@ where .map_err(Error::io) } -async fn authenticate(stream: &mut StartupStream, config: &Config) -> Result<(), Error> +async fn authenticate( + stream: &mut StartupStream, + config: &Config, + user: &str, +) -> Result<(), Error> where S: AsyncRead + AsyncWrite + Unpin, T: TlsStream + Unpin, @@ -157,10 +170,6 @@ where Some(Message::AuthenticationMd5Password(body)) => { can_skip_channel_binding(config)?; - let user = config - .user - .as_ref() - .ok_or_else(|| Error::config("user missing".into()))?; let pass = config .password .as_ref() diff --git a/tokio-postgres/src/connect_socket.rs b/tokio-postgres/src/connect_socket.rs index ae2359e74..02dab3bd9 100644 --- a/tokio-postgres/src/connect_socket.rs +++ b/tokio-postgres/src/connect_socket.rs @@ -1,72 +1,53 @@ -use crate::config::Host; +use crate::client::Addr; use crate::keepalive::KeepaliveConfig; use crate::{Error, Socket}; use socket2::{SockRef, TcpKeepalive}; use std::future::Future; use std::io; use std::time::Duration; +use tokio::net::TcpStream; #[cfg(unix)] use tokio::net::UnixStream; -use tokio::net::{self, TcpStream}; use tokio::time; pub(crate) async fn connect_socket( - host: &Host, + addr: &Addr, port: u16, connect_timeout: Option, - tcp_user_timeout: Option, + #[cfg_attr(not(target_os = "linux"), allow(unused_variables))] tcp_user_timeout: Option< + Duration, + >, keepalive_config: Option<&KeepaliveConfig>, ) -> Result { - match host { - Host::Tcp(host) => { - let addrs = net::lookup_host((&**host, port)) - .await - .map_err(Error::connect)?; + match addr { + Addr::Tcp(ip) => { + let stream = + connect_with_timeout(TcpStream::connect((*ip, port)), connect_timeout).await?; - let mut last_err = None; + stream.set_nodelay(true).map_err(Error::connect)?; - for addr in addrs { - let stream = - match connect_with_timeout(TcpStream::connect(addr), connect_timeout).await { - Ok(stream) => stream, - Err(e) => { - last_err = Some(e); - continue; - } - }; - - stream.set_nodelay(true).map_err(Error::connect)?; - - #[cfg(not(madsim))] - let sock_ref = SockRef::from(&stream); - #[cfg(target_os = "linux")] - #[cfg(not(madsim))] - { - sock_ref - .set_tcp_user_timeout(tcp_user_timeout) - .map_err(Error::connect)?; - } - - #[cfg(not(madsim))] - if let Some(keepalive_config) = keepalive_config { - sock_ref - .set_tcp_keepalive(&TcpKeepalive::from(keepalive_config)) - .map_err(Error::connect)?; - } + #[cfg(not(madsim))] + let sock_ref = SockRef::from(&stream); + #[cfg(target_os = "linux")] + #[cfg(not(madsim))] + { + sock_ref + .set_tcp_user_timeout(tcp_user_timeout) + .map_err(Error::connect)?; + } - return Ok(Socket::new_tcp(stream)); + #[cfg(not(madsim))] + if let Some(keepalive_config) = keepalive_config { + sock_ref + .set_tcp_keepalive(&TcpKeepalive::from(keepalive_config)) + .map_err(Error::connect)?; } - Err(last_err.unwrap_or_else(|| { - Error::connect(io::Error::new( - io::ErrorKind::InvalidInput, - "could not resolve any addresses", - )) - })) + Ok(Socket::new_tcp(stream)) } #[cfg(unix)] - Host::Unix(path) => { - let path = path.join(format!(".s.PGSQL.{}", port)); + Addr::Unix(dir) => { + let path = dir.join(format!(".s.PGSQL.{}", port)); let socket = connect_with_timeout(UnixStream::connect(path), connect_timeout).await?; Ok(Socket::new_unix(socket)) } diff --git a/tokio-postgres/src/connect_tls.rs b/tokio-postgres/src/connect_tls.rs index 5ef21ac5c..2b1229125 100644 --- a/tokio-postgres/src/connect_tls.rs +++ b/tokio-postgres/src/connect_tls.rs @@ -11,6 +11,7 @@ pub async fn connect_tls( mut stream: S, mode: SslMode, tls: T, + has_hostname: bool, ) -> Result, Error> where S: AsyncRead + AsyncWrite + Unpin, @@ -39,6 +40,10 @@ where } } + if !has_hostname { + return Err(Error::tls("no hostname provided for TLS handshake".into())); + } + let stream = tls .connect(stream) .await diff --git a/tokio-postgres/src/keepalive.rs b/tokio-postgres/src/keepalive.rs index 74f453985..c409eb0ea 100644 --- a/tokio-postgres/src/keepalive.rs +++ b/tokio-postgres/src/keepalive.rs @@ -12,12 +12,17 @@ impl From<&KeepaliveConfig> for TcpKeepalive { fn from(keepalive_config: &KeepaliveConfig) -> Self { let mut tcp_keepalive = Self::new().with_time(keepalive_config.idle); - #[cfg(not(any(target_os = "redox", target_os = "solaris")))] + #[cfg(not(any(target_os = "redox", target_os = "solaris", target_os = "openbsd")))] if let Some(interval) = keepalive_config.interval { tcp_keepalive = tcp_keepalive.with_interval(interval); } - #[cfg(not(any(target_os = "redox", target_os = "solaris", target_os = "windows")))] + #[cfg(not(any( + target_os = "redox", + target_os = "solaris", + target_os = "windows", + target_os = "openbsd" + )))] if let Some(retries) = keepalive_config.retries { tcp_keepalive = tcp_keepalive.with_retries(retries); } diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index a9ecba4f1..ff8e93ddc 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -116,7 +116,6 @@ //! | `with-uuid-1` | Enable support for the `uuid` crate. | [uuid](https://crates.io/crates/uuid) 1.0 | no | //! | `with-time-0_2` | Enable support for the 0.2 version of the `time` crate. | [time](https://crates.io/crates/time/0.2.0) 0.2 | no | //! | `with-time-0_3` | Enable support for the 0.3 version of the `time` crate. | [time](https://crates.io/crates/time/0.3.0) 0.3 | no | -#![doc(html_root_url = "https://docs.rs/tokio-postgres/0.7")] #![warn(rust_2018_idioms, clippy::all, missing_docs)] pub use crate::cancel_token::CancelToken; @@ -163,6 +162,7 @@ mod copy_in; mod copy_out; pub mod error; mod generic_client; +#[cfg(not(target_arch = "wasm32"))] mod keepalive; mod maybe_tls_stream; mod portal; diff --git a/tokio-postgres/tests/test/runtime.rs b/tokio-postgres/tests/test/runtime.rs index 67b4ead8a..86c1f0701 100644 --- a/tokio-postgres/tests/test/runtime.rs +++ b/tokio-postgres/tests/test/runtime.rs @@ -66,6 +66,58 @@ async fn target_session_attrs_err() { .unwrap(); } +#[tokio::test] +async fn host_only_ok() { + let _ = tokio_postgres::connect( + "host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_only_ok() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1 port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_and_host_ok() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1 host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_host_mismatch() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1,127.0.0.2 host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .err() + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_host_both_missing() { + let _ = tokio_postgres::connect( + "port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .err() + .unwrap(); +} + #[tokio::test] async fn cancel_query() { let client = connect("host=localhost port=5433 user=postgres").await; diff --git a/tokio-postgres/tests/test/types/mod.rs b/tokio-postgres/tests/test/types/mod.rs index 452d149fe..f1a44da08 100644 --- a/tokio-postgres/tests/test/types/mod.rs +++ b/tokio-postgres/tests/test/types/mod.rs @@ -739,3 +739,25 @@ async fn ltxtquery_any() { ) .await; } + +#[tokio::test] +async fn oidvector() { + test_type( + "oidvector", + // NB: postgres does not support empty oidarrays! All empty arrays are normalized to zero dimensions, but the + // oidvectorrecv function requires exactly one dimension. + &[(Some(vec![0u32, 1, 2]), "ARRAY[0,1,2]"), (None, "NULL")], + ) + .await; +} + +#[tokio::test] +async fn int2vector() { + test_type( + "int2vector", + // NB: postgres does not support empty int2vectors! All empty arrays are normalized to zero dimensions, but the + // oidvectorrecv function requires exactly one dimension. + &[(Some(vec![0i16, 1, 2]), "ARRAY[0,1,2]"), (None, "NULL")], + ) + .await; +}