From cab9f38b834407acdab6b73b73118dea313e5c06 Mon Sep 17 00:00:00 2001 From: Raminder Singh Date: Fri, 30 Aug 2024 17:49:08 +0530 Subject: [PATCH] add create, read and delete api for publications --- api/src/db/publications.rs | 144 +++++++++++++++-- api/src/db/tables.rs | 4 +- api/src/routes/sources/publications.rs | 205 +++++++++++++++++++++++++ api/src/routes/sources/tables.rs | 2 +- api/src/startup.rs | 14 +- cli/src/api_client.rs | 55 ------- 6 files changed, 355 insertions(+), 69 deletions(-) diff --git a/api/src/db/publications.rs b/api/src/db/publications.rs index d92acb9..ba2cf74 100644 --- a/api/src/db/publications.rs +++ b/api/src/db/publications.rs @@ -1,6 +1,8 @@ -use std::borrow::Cow; +use std::{borrow::Cow, collections::HashMap}; -use sqlx::{postgres::PgConnectOptions, Connection, Executor, PgConnection}; +use serde::Serialize; +use sqlx::{postgres::PgConnectOptions, Connection, Executor, PgConnection, Row}; +use tracing::error; use super::tables::Table; @@ -23,25 +25,54 @@ fn quote_identifier_alloc(identifier: &str) -> String { quoted_identifier } -pub async fn create_publication_on_source( - publication_name: &str, - tables: &[Table], +pub fn quote_literal(literal: &str) -> String { + let mut quoted_literal = String::with_capacity(literal.len() + 2); + + if literal.find('\\').is_some() { + quoted_literal.push('E'); + } + + quoted_literal.push('\''); + + for char in literal.chars() { + if char == '\'' { + quoted_literal.push('\''); + } else if char == '\\' { + quoted_literal.push('\\'); + } + + quoted_literal.push(char); + } + + quoted_literal.push('\''); + + quoted_literal +} + +#[derive(Serialize)] +pub struct Publication { + pub name: String, + pub tables: Vec, +} + +pub async fn create_publication( + publication: &Publication, options: &PgConnectOptions, ) -> Result<(), sqlx::Error> { let mut query = String::new(); - let quoted_publication_name = quote_identifier(publication_name); + let quoted_publication_name = quote_identifier(&publication.name); query.push_str("create publication "); query.push_str("ed_publication_name); - query.push_str(" table only "); + query.push_str(" for table only "); - for (i, table) in tables.iter().enumerate() { + for (i, table) in publication.tables.iter().enumerate() { let quoted_schema = quote_identifier(&table.schema); let quoted_name = quote_identifier(&table.name); query.push_str("ed_schema); query.push('.'); query.push_str("ed_name); - if i < tables.len() - 1 { + if i < publication.tables.len() - 1 { query.push(',') } } @@ -51,3 +82,98 @@ pub async fn create_publication_on_source( Ok(()) } + +pub async fn drop_publication( + publication_name: &str, + options: &PgConnectOptions, +) -> Result<(), sqlx::Error> { + let mut query = String::new(); + query.push_str("drop publication if exists "); + let quoted_publication_name = quote_identifier(publication_name); + query.push_str("ed_publication_name); + + let mut connection = PgConnection::connect_with(options).await?; + connection.execute(query.as_str()).await?; + + Ok(()) +} + +pub async fn read_publication( + publication_name: &str, + options: &PgConnectOptions, +) -> Result, sqlx::Error> { + let mut query = String::new(); + query.push_str( + r#" + select p.pubname, pt.schemaname, pt.tablename from pg_publication p + join pg_publication_tables pt on p.pubname = pt.pubname + where + p.puballtables = false + and p.pubinsert = true + and p.pubupdate = true + and p.pubdelete = true + and p.pubtruncate = true + and p.pubname = + "#, + ); + + let quoted_publication_name = quote_literal(publication_name); + query.push_str("ed_publication_name); + + error!("QUERY: {query}"); + + let mut connection = PgConnection::connect_with(options).await?; + + let mut tables = vec![]; + let mut name: Option = None; + + for row in connection.fetch_all(query.as_str()).await? { + let pub_name: String = row.get("pubname"); + if let Some(ref name) = name { + assert_eq!(name.as_str(), pub_name); + } else { + name = Some(pub_name); + } + let schema = row.get("schemaname"); + let name = row.get("tablename"); + tables.push(Table { schema, name }); + } + + let publication = name.map(|name| Publication { name, tables }); + + Ok(publication) +} + +pub async fn read_all_publications( + options: &PgConnectOptions, +) -> Result, sqlx::Error> { + let query = r#" + select p.pubname, pt.schemaname, pt.tablename from pg_publication p + join pg_publication_tables pt on p.pubname = pt.pubname + where + p.puballtables = false + and p.pubinsert = true + and p.pubupdate = true + and p.pubdelete = true + and p.pubtruncate = true; + "#; + + let mut connection = PgConnection::connect_with(options).await?; + + let mut pub_name_to_tables: HashMap> = HashMap::new(); + + for row in connection.fetch_all(query).await? { + let pub_name: String = row.get("pubname"); + let schema = row.get("schemaname"); + let name = row.get("tablename"); + let tables = pub_name_to_tables.entry(pub_name).or_default(); + tables.push(Table { schema, name }); + } + + let publications = pub_name_to_tables + .into_iter() + .map(|(name, tables)| Publication { name, tables }) + .collect(); + + Ok(publications) +} diff --git a/api/src/db/tables.rs b/api/src/db/tables.rs index 6ba1a37..07da78f 100644 --- a/api/src/db/tables.rs +++ b/api/src/db/tables.rs @@ -1,7 +1,7 @@ -use serde::Serialize; +use serde::{Deserialize, Serialize}; use sqlx::{postgres::PgConnectOptions, Connection, Executor, PgConnection, Row}; -#[derive(Serialize)] +#[derive(Serialize, Deserialize)] pub struct Table { pub schema: String, pub name: String, diff --git a/api/src/routes/sources/publications.rs b/api/src/routes/sources/publications.rs index 8b13789..12eda82 100644 --- a/api/src/routes/sources/publications.rs +++ b/api/src/routes/sources/publications.rs @@ -1 +1,206 @@ +use actix_web::{ + delete, get, + http::{header::ContentType, StatusCode}, + post, + web::{Data, Json, Path}, + HttpRequest, HttpResponse, Responder, ResponseError, +}; +use serde::Deserialize; +use sqlx::PgPool; +use thiserror::Error; +use crate::{ + db::{self, publications::Publication, sources::SourceConfig, tables::Table}, + routes::ErrorMessage, +}; + +#[derive(Debug, Error)] +enum PublicationError { + #[error("database error: {0}")] + DatabaseError(#[from] sqlx::Error), + + #[error("source with id {0} not found")] + SourceNotFound(i64), + + #[error("publication with name {0} not found")] + PublicationNotFound(String), + + #[error("tenant id missing in request")] + TenantIdMissing, + + #[error("tenant id ill formed in request")] + TenantIdIllFormed, + + #[error("invalid source config")] + InvalidConfig(#[from] serde_json::Error), +} + +impl PublicationError { + fn to_message(&self) -> String { + match self { + // Do not expose internal database details in error messages + PublicationError::DatabaseError(_) => "internal server error".to_string(), + // Every other message is ok, as they do not divulge sensitive information + e => e.to_string(), + } + } +} + +impl ResponseError for PublicationError { + fn status_code(&self) -> StatusCode { + match self { + PublicationError::DatabaseError(_) | PublicationError::InvalidConfig(_) => { + StatusCode::INTERNAL_SERVER_ERROR + } + PublicationError::SourceNotFound(_) | PublicationError::PublicationNotFound(_) => { + StatusCode::NOT_FOUND + } + PublicationError::TenantIdMissing | PublicationError::TenantIdIllFormed => { + StatusCode::BAD_REQUEST + } + } + } + + fn error_response(&self) -> HttpResponse { + let error_message = ErrorMessage { + error: self.to_message(), + }; + let body = + serde_json::to_string(&error_message).expect("failed to serialize error message"); + HttpResponse::build(self.status_code()) + .insert_header(ContentType::json()) + .body(body) + } +} + +// TODO: read tenant_id from a jwt +fn extract_tenant_id(req: &HttpRequest) -> Result { + let headers = req.headers(); + let tenant_id = headers + .get("tenant_id") + .ok_or(PublicationError::TenantIdMissing)?; + let tenant_id = tenant_id + .to_str() + .map_err(|_| PublicationError::TenantIdIllFormed)?; + let tenant_id: i64 = tenant_id + .parse() + .map_err(|_| PublicationError::TenantIdIllFormed)?; + Ok(tenant_id) +} + +#[derive(Deserialize)] +struct CreatePublicationRequest { + name: String, + tables: Vec
, +} + +#[derive(Deserialize)] +struct DeletePublicationRequest { + name: String, +} + +#[post("/sources/{source_id}/publications")] +pub async fn create_publication( + req: HttpRequest, + pool: Data, + source_id: Path, + publication: Json, +) -> Result { + let tenant_id = extract_tenant_id(&req)?; + let source_id = source_id.into_inner(); + + let config = db::sources::read_source(&pool, tenant_id, source_id) + .await? + .map(|s| { + let config: SourceConfig = serde_json::from_value(s.config)?; + Ok::(config) + }) + .transpose()? + .ok_or(PublicationError::SourceNotFound(source_id))?; + + let options = config.connect_options(); + let publication = publication.0; + let publication = Publication { + name: publication.name, + tables: publication.tables, + }; + db::publications::create_publication(&publication, &options).await?; + + Ok(HttpResponse::Ok().finish()) +} + +#[get("/sources/{source_id}/publications/{publication_name}")] +pub async fn read_publication( + req: HttpRequest, + pool: Data, + source_id_and_pub_name: Path<(i64, String)>, +) -> Result { + let tenant_id = extract_tenant_id(&req)?; + let (source_id, publication_name) = source_id_and_pub_name.into_inner(); + + let config = db::sources::read_source(&pool, tenant_id, source_id) + .await? + .map(|s| { + let config: SourceConfig = serde_json::from_value(s.config)?; + Ok::(config) + }) + .transpose()? + .ok_or(PublicationError::SourceNotFound(source_id))?; + + let options = config.connect_options(); + let publications = db::publications::read_publication(&publication_name, &options) + .await? + .ok_or(PublicationError::PublicationNotFound(publication_name))?; + + Ok(Json(publications)) +} + +#[get("/sources/{source_id}/publications")] +pub async fn read_all_publications( + req: HttpRequest, + pool: Data, + source_id: Path, +) -> Result { + let tenant_id = extract_tenant_id(&req)?; + let source_id = source_id.into_inner(); + + let config = db::sources::read_source(&pool, tenant_id, source_id) + .await? + .map(|s| { + let config: SourceConfig = serde_json::from_value(s.config)?; + Ok::(config) + }) + .transpose()? + .ok_or(PublicationError::SourceNotFound(source_id))?; + + let options = config.connect_options(); + let publications = db::publications::read_all_publications(&options).await?; + + Ok(Json(publications)) +} + +#[delete("/sources/{source_id}/publications")] +pub async fn delete_publication( + req: HttpRequest, + pool: Data, + source_id: Path, + publication: Json, +) -> Result { + let tenant_id = extract_tenant_id(&req)?; + let source_id = source_id.into_inner(); + + let config = db::sources::read_source(&pool, tenant_id, source_id) + .await? + .map(|s| { + let config: SourceConfig = serde_json::from_value(s.config)?; + Ok::(config) + }) + .transpose()? + .ok_or(PublicationError::SourceNotFound(source_id))?; + + let options = config.connect_options(); + let publication = publication.0; + db::publications::drop_publication(&publication.name, &options).await?; + + Ok(HttpResponse::Ok().finish()) +} diff --git a/api/src/routes/sources/tables.rs b/api/src/routes/sources/tables.rs index 73d0cbf..4c9ddf6 100644 --- a/api/src/routes/sources/tables.rs +++ b/api/src/routes/sources/tables.rs @@ -79,7 +79,7 @@ fn extract_tenant_id(req: &HttpRequest) -> Result { Ok(tenant_id) } -#[get("/sources/{source_id}/table_names")] +#[get("/sources/{source_id}/tables")] pub async fn read_table_names( req: HttpRequest, pool: Data, diff --git a/api/src/startup.rs b/api/src/startup.rs index 19c1ee9..9f0acc0 100644 --- a/api/src/startup.rs +++ b/api/src/startup.rs @@ -13,7 +13,12 @@ use crate::{ }, sinks::{create_sink, delete_sink, read_all_sinks, read_sink, update_sink}, sources::{ - create_source, delete_source, read_all_sources, read_source, tables::read_table_names, + create_source, delete_source, + publications::{ + create_publication, delete_publication, read_all_publications, read_publication, + }, + read_all_sources, read_source, + tables::read_table_names, update_source, }, tenants::{create_tenant, delete_tenant, read_all_tenants, read_tenant, update_tenant}, @@ -86,7 +91,12 @@ pub async fn run(listener: TcpListener, connection_pool: PgPool) -> Result, -} - -impl Display for PublicationConfig { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "table_names: {:?}", self.table_names) - } -} - #[derive(Serialize)] pub struct CreateTenantRequest { pub name: String, @@ -275,50 +264,6 @@ pub struct UpdatePipelineRequest { pub config: PipelineConfig, } -#[derive(Serialize)] -pub struct CreatePublicationRequest { - pub source_id: i64, - pub name: String, - pub config: PublicationConfig, -} - -#[derive(Deserialize)] -pub struct CreatePublicationResponse { - pub id: i64, -} - -impl Display for CreatePublicationResponse { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "id: {}", self.id) - } -} - -#[derive(Deserialize)] -pub struct PublicationResponse { - pub id: i64, - pub tenant_id: i64, - pub source_id: i64, - pub name: String, - pub config: PublicationConfig, -} - -impl Display for PublicationResponse { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "tenant_id: {}, id: {}, source_id: {}, name: {}, config: {}", - self.tenant_id, self.id, self.source_id, self.name, self.config - ) - } -} - -#[derive(Serialize)] -pub struct UpdatePublicationRequest { - pub source_id: i64, - pub name: String, - pub config: PublicationConfig, -} - #[derive(Debug, Error)] pub enum ApiClientError { #[error("reqwest error: {0}")]