Skip to content

Commit

Permalink
proxy: merge AuthError and AuthErrorImpl (#9418)
Browse files Browse the repository at this point in the history
Since GetAuthInfoError now boxes the ControlPlaneError message the
variant is not big anymore and AuthError is 32 bytes.
  • Loading branch information
cloneable authored Oct 16, 2024
1 parent 8a114e3 commit ed69473
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 49 deletions.
10 changes: 5 additions & 5 deletions proxy/src/auth/flow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;

use super::backend::ComputeCredentialKeys;
use super::{AuthErrorImpl, PasswordHackPayload};
use super::{AuthError, PasswordHackPayload};
use crate::config::TlsServerEndPoint;
use crate::context::RequestMonitoring;
use crate::control_plane::AuthSecret;
Expand Down Expand Up @@ -117,14 +117,14 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
let msg = self.stream.read_password_message().await?;
let password = msg
.strip_suffix(&[0])
.ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
.ok_or(AuthError::MalformedPassword("missing terminator"))?;

let payload = PasswordHackPayload::parse(password)
// If we ended up here and the payload is malformed, it means that
// the user neither enabled SNI nor resorted to any other method
// for passing the project name we rely on. We should show them
// the most helpful error message and point to the documentation.
.ok_or(AuthErrorImpl::MissingEndpointName)?;
.ok_or(AuthError::MissingEndpointName)?;

Ok(payload)
}
Expand All @@ -136,7 +136,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
let msg = self.stream.read_password_message().await?;
let password = msg
.strip_suffix(&[0])
.ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
.ok_or(AuthError::MalformedPassword("missing terminator"))?;

let outcome = validate_password_and_exchange(
&self.state.pool,
Expand Down Expand Up @@ -166,7 +166,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
// Initial client message contains the chosen auth method's name.
let msg = self.stream.read_password_message().await?;
let sasl = sasl::FirstMessage::parse(&msg)
.ok_or(AuthErrorImpl::MalformedPassword("bad sasl message"))?;
.ok_or(AuthError::MalformedPassword("bad sasl message"))?;

// Currently, the only supported SASL method is SCRAM.
if !scram::METHODS.contains(&sasl.method) {
Expand Down
78 changes: 34 additions & 44 deletions proxy/src/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub(crate) type Result<T> = std::result::Result<T, AuthError>;

/// Common authentication error.
#[derive(Debug, Error)]
pub(crate) enum AuthErrorImpl {
pub(crate) enum AuthError {
#[error(transparent)]
Web(#[from] backend::WebAuthError),

Expand Down Expand Up @@ -78,80 +78,70 @@ pub(crate) enum AuthErrorImpl {
ConfirmationTimeout(humantime::Duration),
}

#[derive(Debug, Error)]
#[error(transparent)]
pub(crate) struct AuthError(Box<AuthErrorImpl>);

impl AuthError {
pub(crate) fn bad_auth_method(name: impl Into<Box<str>>) -> Self {
AuthErrorImpl::BadAuthMethod(name.into()).into()
AuthError::BadAuthMethod(name.into())
}

pub(crate) fn auth_failed(user: impl Into<Box<str>>) -> Self {
AuthErrorImpl::AuthFailed(user.into()).into()
AuthError::AuthFailed(user.into())
}

pub(crate) fn ip_address_not_allowed(ip: IpAddr) -> Self {
AuthErrorImpl::IpAddressNotAllowed(ip).into()
AuthError::IpAddressNotAllowed(ip)
}

pub(crate) fn too_many_connections() -> Self {
AuthErrorImpl::TooManyConnections.into()
AuthError::TooManyConnections
}

pub(crate) fn is_auth_failed(&self) -> bool {
matches!(self.0.as_ref(), AuthErrorImpl::AuthFailed(_))
matches!(self, AuthError::AuthFailed(_))
}

pub(crate) fn user_timeout(elapsed: Elapsed) -> Self {
AuthErrorImpl::UserTimeout(elapsed).into()
AuthError::UserTimeout(elapsed)
}

pub(crate) fn confirmation_timeout(timeout: humantime::Duration) -> Self {
AuthErrorImpl::ConfirmationTimeout(timeout).into()
}
}

impl<E: Into<AuthErrorImpl>> From<E> for AuthError {
fn from(e: E) -> Self {
Self(Box::new(e.into()))
AuthError::ConfirmationTimeout(timeout)
}
}

impl UserFacingError for AuthError {
fn to_string_client(&self) -> String {
match self.0.as_ref() {
AuthErrorImpl::Web(e) => e.to_string_client(),
AuthErrorImpl::GetAuthInfo(e) => e.to_string_client(),
AuthErrorImpl::Sasl(e) => e.to_string_client(),
AuthErrorImpl::AuthFailed(_) => self.to_string(),
AuthErrorImpl::BadAuthMethod(_) => self.to_string(),
AuthErrorImpl::MalformedPassword(_) => self.to_string(),
AuthErrorImpl::MissingEndpointName => self.to_string(),
AuthErrorImpl::Io(_) => "Internal error".to_string(),
AuthErrorImpl::IpAddressNotAllowed(_) => self.to_string(),
AuthErrorImpl::TooManyConnections => self.to_string(),
AuthErrorImpl::UserTimeout(_) => self.to_string(),
AuthErrorImpl::ConfirmationTimeout(_) => self.to_string(),
match self {
Self::Web(e) => e.to_string_client(),
Self::GetAuthInfo(e) => e.to_string_client(),
Self::Sasl(e) => e.to_string_client(),
Self::AuthFailed(_) => self.to_string(),
Self::BadAuthMethod(_) => self.to_string(),
Self::MalformedPassword(_) => self.to_string(),
Self::MissingEndpointName => self.to_string(),
Self::Io(_) => "Internal error".to_string(),
Self::IpAddressNotAllowed(_) => self.to_string(),
Self::TooManyConnections => self.to_string(),
Self::UserTimeout(_) => self.to_string(),
Self::ConfirmationTimeout(_) => self.to_string(),
}
}
}

impl ReportableError for AuthError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self.0.as_ref() {
AuthErrorImpl::Web(e) => e.get_error_kind(),
AuthErrorImpl::GetAuthInfo(e) => e.get_error_kind(),
AuthErrorImpl::Sasl(e) => e.get_error_kind(),
AuthErrorImpl::AuthFailed(_) => crate::error::ErrorKind::User,
AuthErrorImpl::BadAuthMethod(_) => crate::error::ErrorKind::User,
AuthErrorImpl::MalformedPassword(_) => crate::error::ErrorKind::User,
AuthErrorImpl::MissingEndpointName => crate::error::ErrorKind::User,
AuthErrorImpl::Io(_) => crate::error::ErrorKind::ClientDisconnect,
AuthErrorImpl::IpAddressNotAllowed(_) => crate::error::ErrorKind::User,
AuthErrorImpl::TooManyConnections => crate::error::ErrorKind::RateLimit,
AuthErrorImpl::UserTimeout(_) => crate::error::ErrorKind::User,
AuthErrorImpl::ConfirmationTimeout(_) => crate::error::ErrorKind::User,
match self {
Self::Web(e) => e.get_error_kind(),
Self::GetAuthInfo(e) => e.get_error_kind(),
Self::Sasl(e) => e.get_error_kind(),
Self::AuthFailed(_) => crate::error::ErrorKind::User,
Self::BadAuthMethod(_) => crate::error::ErrorKind::User,
Self::MalformedPassword(_) => crate::error::ErrorKind::User,
Self::MissingEndpointName => crate::error::ErrorKind::User,
Self::Io(_) => crate::error::ErrorKind::ClientDisconnect,
Self::IpAddressNotAllowed(_) => crate::error::ErrorKind::User,
Self::TooManyConnections => crate::error::ErrorKind::RateLimit,
Self::UserTimeout(_) => crate::error::ErrorKind::User,
Self::ConfirmationTimeout(_) => crate::error::ErrorKind::User,
}
}
}

1 comment on commit ed69473

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

5290 tests run: 5073 passed, 0 failed, 217 skipped (full report)


Flaky tests (4)

Postgres 17

Postgres 16

Code coverage* (full report)

  • functions: 31.3% (7549 of 24105 functions)
  • lines: 49.1% (60416 of 122933 lines)

* collected from Rust tests only


The comment gets automatically updated with the latest test results
ed69473 at 2024-10-16T18:31:59.722Z :recycle:

Please sign in to comment.