diff --git a/wrappers/Cargo.lock b/wrappers/Cargo.lock
index 5f8dc845..353a2a64 100644
--- a/wrappers/Cargo.lock
+++ b/wrappers/Cargo.lock
@@ -4243,6 +4243,8 @@ dependencies = [
  "async-compression",
  "aws-config",
  "aws-sdk-s3",
+ "aws-smithy-http",
+ "aws-smithy-runtime-api",
  "chrono",
  "chrono-tz",
  "clickhouse-rs",
diff --git a/wrappers/Cargo.toml b/wrappers/Cargo.toml
index 3522bb37..4445c68f 100644
--- a/wrappers/Cargo.toml
+++ b/wrappers/Cargo.toml
@@ -22,7 +22,7 @@ clickhouse_fdw = ["clickhouse-rs", "chrono", "chrono-tz", "regex", "thiserror"]
 stripe_fdw = ["reqwest", "reqwest-middleware", "reqwest-retry", "serde_json", "thiserror", "url"]
 firebase_fdw = ["reqwest", "reqwest-middleware", "reqwest-retry", "serde_json", "yup-oauth2", "regex", "thiserror"]
 s3_fdw = [
-  "reqwest", "reqwest-middleware", "reqwest-retry", "aws-config", "aws-sdk-s3",
+  "reqwest", "reqwest-middleware", "reqwest-retry", "aws-config", "aws-sdk-s3", "aws-smithy-http", "aws-smithy-runtime-api",
   "tokio", "tokio-util", "csv", "async-compression", "serde_json", "http", "parquet",
   "futures", "arrow-array", "chrono", "thiserror"
 ]
@@ -64,6 +64,8 @@ url = { version = "2.3", optional = true }
 # for s3_fdw
 aws-config = { version = "0.56.1", optional = true }
 aws-sdk-s3 = { version = "0.30.0", optional = true }
+aws-smithy-http = { version = "0.56.1", optional = true }
+aws-smithy-runtime-api = { version = "0.56.1", optional = true }
 csv = { version = "1.2", optional = true }
 tokio = { version = "1", features = ["full"], optional = true }
 tokio-util = { version = "0.7", optional = true }
diff --git a/wrappers/src/fdw/s3_fdw/mod.rs b/wrappers/src/fdw/s3_fdw/mod.rs
index 48bd6534..c2a09e0a 100644
--- a/wrappers/src/fdw/s3_fdw/mod.rs
+++ b/wrappers/src/fdw/s3_fdw/mod.rs
@@ -2,3 +2,68 @@
 mod parquet;
 mod s3_fdw;
 mod tests;
+
+use aws_sdk_s3::operation::get_object::GetObjectError;
+use aws_smithy_http::result::SdkError;
+use aws_smithy_runtime_api::client::orchestrator::HttpResponse;
+use pgrx::pg_sys::panic::ErrorReport;
+use pgrx::prelude::PgSqlErrorCode;
+use thiserror::Error;
+
+use supabase_wrappers::prelude::{CreateRuntimeError, OptionsError};
+
+#[derive(Error, Debug)]
+enum S3FdwError {
+    #[error("invalid s3 uri: {0}")]
+    InvalidS3Uri(String),
+
+    #[error("invalid format option: '{0}', it can only be 'csv', 'jsonl' or 'parquet'")]
+    InvalidFormatOption(String),
+
+    #[error("invalid compression option: {0}")]
+    InvalidCompressOption(String),
+
+    #[error("read line failed: {0}")]
+    ReadLineError(#[from] std::io::Error),
+
+    #[error("read csv record failed: {0}")]
+    ReadCsvError(#[from] csv::Error),
+
+    #[error("read jsonl record failed: {0}")]
+    ReadJsonlError(String),
+
+    #[error("read parquet failed: {0}")]
+    ReadParquetError(#[from] ::parquet::errors::ParquetError),
+
+    #[error("column '{0}' data type is not supported")]
+    UnsupportedColumnType(String),
+
+    #[error("column '{0}' data type not match")]
+    ColumnTypeNotMatch(String),
+
+    #[error("column {0} not found in parquet file")]
+    ColumnNotFound(String),
+
+    #[error("{0}")]
+    OptionsError(#[from] OptionsError),
+
+    #[error("{0}")]
+    CreateRuntimeError(#[from] CreateRuntimeError),
+
+    #[error("parse uri failed: {0}")]
+    UriParseError(#[from] http::uri::InvalidUri),
+
+    #[error("request failed: {0}")]
+    RequestError(#[from] SdkError<GetObjectError, HttpResponse>),
+
+    #[error("parse JSON response failed: {0}")]
+    JsonParseError(#[from] serde_json::Error),
+}
+
+impl From<S3FdwError> for ErrorReport {
+    fn from(value: S3FdwError) -> Self {
+        ErrorReport::new(PgSqlErrorCode::ERRCODE_FDW_ERROR, format!("{value}"), "")
format!("{value}"), "") + } +} + +type S3FdwResult = Result; diff --git a/wrappers/src/fdw/s3_fdw/parquet.rs b/wrappers/src/fdw/s3_fdw/parquet.rs index 56137949..9380ccb7 100644 --- a/wrappers/src/fdw/s3_fdw/parquet.rs +++ b/wrappers/src/fdw/s3_fdw/parquet.rs @@ -9,7 +9,7 @@ use parquet::arrow::async_reader::{ use parquet::arrow::ProjectionMask; use pgrx::datum::datetime_support::to_timestamp; use pgrx::pg_sys; -use pgrx::prelude::{Date, PgSqlErrorCode}; +use pgrx::prelude::Date; use std::cmp::min; use std::io::{Cursor, Error as IoError, ErrorKind, Result as IoResult, SeekFrom}; use std::pin::Pin; @@ -19,6 +19,8 @@ use tokio::runtime::Handle; use supabase_wrappers::prelude::*; +use super::{S3FdwError, S3FdwResult}; + // convert an error to IO error #[inline] fn to_io_error(err: impl std::error::Error) -> IoError { @@ -143,11 +145,12 @@ impl S3Parquet { const FDW_NAME: &str = "S3Fdw"; // open batch stream from local buffer - pub(super) async fn open_local_stream(&mut self, buf: Vec) { + pub(super) async fn open_local_stream(&mut self, buf: Vec) -> S3FdwResult<()> { let cursor: Box = Box::new(Cursor::new(buf)); - let builder = ParquetRecordBatchStreamBuilder::new(cursor).await.unwrap(); + let builder = ParquetRecordBatchStreamBuilder::new(cursor).await?; let stream = builder.build().unwrap(); self.stream = Some(stream); + Ok(()) } // open async record batch stream @@ -160,7 +163,7 @@ impl S3Parquet { bucket: &str, object: &str, tgt_cols: &[Column], - ) { + ) -> S3FdwResult<()> { let handle = Handle::current(); let rdr = S3ParquetReader::new(client, bucket, object); @@ -201,72 +204,65 @@ impl S3Parquet { let mask = ProjectionMask::roots(schema, project_indexes); builder.with_projection(mask).build() }) - .map_err(to_io_error) - .unwrap(); + .map_err(|err| parquet::errors::ParquetError::General(err.to_string()))?; self.stream = Some(stream); self.batch = None; self.batch_idx = 0; + + Ok(()) } // refill record batch - pub(super) async fn refill(&mut self) -> Option<()> { + pub(super) async fn refill(&mut self) -> S3FdwResult> { // if there are still records in the batch if let Some(batch) = &self.batch { if self.batch_idx < batch.num_rows() { - return Some(()); + return Ok(Some(())); } } // otherwise, read one moe batch if let Some(ref mut stream) = &mut self.stream { - match stream.try_next().await { - Ok(result) => { - return result.map(|batch| { - stats::inc_stats( - Self::FDW_NAME, - stats::Metric::RowsIn, - batch.num_rows() as i64, - ); - stats::inc_stats( - Self::FDW_NAME, - stats::Metric::BytesIn, - batch.get_array_memory_size() as i64, - ); + let result = stream.try_next().await?; + return Ok(result.map(|batch| { + stats::inc_stats( + Self::FDW_NAME, + stats::Metric::RowsIn, + batch.num_rows() as i64, + ); + stats::inc_stats( + Self::FDW_NAME, + stats::Metric::BytesIn, + batch.get_array_memory_size() as i64, + ); - self.batch = Some(batch); - self.batch_idx = 0; - }) - } - Err(err) => { - report_error( - PgSqlErrorCode::ERRCODE_FDW_ERROR, - &format!("read parquet record batch failed: {}", err), - ); - return None; - } - } + self.batch = Some(batch); + self.batch_idx = 0; + })); } - None + Ok(None) } // read one row from record batch - pub(super) fn read_into_row(&mut self, row: &mut Row, tgt_cols: &Vec) -> Option<()> { + pub(super) fn read_into_row( + &mut self, + row: &mut Row, + tgt_cols: &Vec, + ) -> S3FdwResult> { if let Some(batch) = &self.batch { for tgt_col in tgt_cols { let col = batch .column_by_name(&tgt_col.name) - .unwrap_or_else(|| panic!("column {} not found in parquet 
file", tgt_col.name)); + .ok_or(S3FdwError::ColumnNotFound(tgt_col.name.clone()))?; macro_rules! col_to_cell { ($array_type:ident, $cell_type:ident) => {{ let arr = col .as_any() .downcast_ref::() - .unwrap_or_else(|| { - panic!("column '{}' data type not match", tgt_col.name) - }); + .ok_or(S3FdwError::ColumnTypeNotMatch(tgt_col.name.clone()))?; if arr.is_null(self.batch_idx) { None } else { @@ -287,9 +283,7 @@ impl S3Parquet { let arr = col .as_any() .downcast_ref::() - .unwrap_or_else(|| { - panic!("column '{}' data type not match", tgt_col.name) - }); + .ok_or(S3FdwError::ColumnTypeNotMatch(tgt_col.name.clone()))?; if arr.is_null(self.batch_idx) { None } else { @@ -302,9 +296,7 @@ impl S3Parquet { let arr = col .as_any() .downcast_ref::() - .unwrap_or_else(|| { - panic!("column '{}' data type not match", tgt_col.name) - }); + .ok_or(S3FdwError::ColumnTypeNotMatch(tgt_col.name.clone()))?; if arr.is_null(self.batch_idx) { None } else { @@ -316,9 +308,7 @@ impl S3Parquet { let arr = col .as_any() .downcast_ref::() - .unwrap_or_else(|| { - panic!("column '{}' data type not match", tgt_col.name) - }); + .ok_or(S3FdwError::ColumnTypeNotMatch(tgt_col.name.clone()))?; if arr.is_null(self.batch_idx) { None } else { @@ -335,9 +325,7 @@ impl S3Parquet { let arr = col .as_any() .downcast_ref::() - .unwrap_or_else(|| { - panic!("column '{}' data type not match", tgt_col.name) - }); + .ok_or(S3FdwError::ColumnTypeNotMatch(tgt_col.name.clone()))?; if arr.is_null(self.batch_idx) { None } else { @@ -347,19 +335,13 @@ impl S3Parquet { }) } } - _ => { - report_error( - PgSqlErrorCode::ERRCODE_FDW_ERROR, - &format!("column '{}' data type not supported", tgt_col.name), - ); - None - } + _ => return Err(S3FdwError::UnsupportedColumnType(tgt_col.name.clone())), }; row.push(&tgt_col.name, cell); } self.batch_idx += 1; - return Some(()); + return Ok(Some(())); } - None + Ok(None) } } diff --git a/wrappers/src/fdw/s3_fdw/s3_fdw.rs b/wrappers/src/fdw/s3_fdw/s3_fdw.rs index 5cdcf4c4..5de35dc2 100644 --- a/wrappers/src/fdw/s3_fdw/s3_fdw.rs +++ b/wrappers/src/fdw/s3_fdw/s3_fdw.rs @@ -3,19 +3,18 @@ use async_compression::tokio::bufread::{BzDecoder, GzipDecoder, XzDecoder, ZlibD use aws_sdk_s3 as s3; use http::Uri; use pgrx::pg_sys; -use pgrx::pg_sys::panic::ErrorReport; -use pgrx::prelude::PgSqlErrorCode; use serde_json::{self, Value as JsonValue}; use std::collections::{HashMap, VecDeque}; use std::env; use std::io::Cursor; use std::pin::Pin; -use thiserror::Error; use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, BufReader}; use super::parquet::*; use supabase_wrappers::prelude::*; +use super::{S3FdwError, S3FdwResult}; + // record parser for a S3 file enum Parser { Csv(csv::Reader>>), @@ -25,7 +24,7 @@ enum Parser { } #[wrappers_fdw( - version = "0.1.2", + version = "0.1.3", author = "Supabase", website = "https://github.com/supabase/wrappers/tree/main/wrappers/src/fdw/s3_fdw", error_type = "S3FdwError" @@ -54,9 +53,9 @@ impl S3Fdw { // Returns: // Some - still have records to read // None - no more records - fn refill(&mut self) -> Option<()> { + fn refill(&mut self) -> S3FdwResult> { if !self.buf.is_empty() { - return Some(()); + return Ok(Some(())); } if let Some(ref mut rdr) = self.rdr { @@ -64,21 +63,11 @@ impl S3Fdw { let mut total_lines = 0; let mut total_bytes = 0; loop { - match self.rt.block_on(rdr.read_line(&mut self.buf)) { - Ok(num_bytes) => { - total_lines += 1; - total_bytes += num_bytes; - if num_bytes == 0 || self.buf.len() > Self::BUF_SIZE { - break; - } - } - Err(err) => { - 
-                        report_error(
-                            PgSqlErrorCode::ERRCODE_FDW_ERROR,
-                            &format!("fetch query result failed: {}", err),
-                        );
-                        return None;
-                    }
+                let num_bytes = self.rt.block_on(rdr.read_line(&mut self.buf))?;
+                total_lines += 1;
+                total_bytes += num_bytes;
+                if num_bytes == 0 || self.buf.len() > Self::BUF_SIZE {
+                    break;
                 }
             }
 
@@ -87,7 +76,7 @@ impl S3Fdw {
         }
 
         if self.buf.is_empty() {
-            return None;
+            return Ok(None);
         }
 
         match &mut self.parser {
@@ -107,44 +96,24 @@ impl S3Fdw {
                     .collect::>()
                     .join(",");
                 let json_str = format!("{{ \"rows\": [{}] }}", s.trim_end_matches(','));
-                match serde_json::from_str::<JsonValue>(&json_str) {
-                    Ok(rows) => {
-                        *records =
-                            VecDeque::from(rows.get("rows").unwrap().as_array().unwrap().to_vec());
-                    }
-                    Err(err) => {
-                        report_error(
-                            PgSqlErrorCode::ERRCODE_FDW_ERROR,
-                            &format!("parse json line file failed: {}", err),
-                        );
-                        return None;
-                    }
-                }
+                let rows = serde_json::from_str::<JsonValue>(&json_str)?;
+                let rows = rows
+                    .get("rows")
+                    .and_then(|arr| arr.as_array())
+                    .ok_or(S3FdwError::ReadJsonlError(json_str))?;
+                *records = VecDeque::from(rows.to_vec());
             }
             _ => unreachable!(),
         }
 
-        Some(())
-    }
-}
-
-#[derive(Error, Debug)]
-enum S3FdwError {
-    #[error("{0}")]
-    OptionsError(#[from] OptionsError),
-}
-
-impl From<S3FdwError> for ErrorReport {
-    fn from(value: S3FdwError) -> Self {
-        match value {
-            S3FdwError::OptionsError(e) => e.into(),
-        }
+        Ok(Some(()))
     }
 }
 
 impl ForeignDataWrapper<S3FdwError> for S3Fdw {
-    fn new(options: &HashMap<String, String>) -> Result<Self, S3FdwError> {
-        let rt = tokio::runtime::Runtime::new().unwrap();
+    fn new(options: &HashMap<String, String>) -> S3FdwResult<Self> {
+        let rt = tokio::runtime::Runtime::new()
+            .map_err(CreateRuntimeError::FailedToCreateAsyncRuntime)?;
         let mut ret = S3Fdw {
             rt,
             client: None,
@@ -169,7 +138,7 @@ impl ForeignDataWrapper<S3FdwError> for S3Fdw {
                 let vault_secret_access_key = require_option("vault_secret_access_key", options)?;
 
                 get_vault_secret(vault_access_key_id)
-                    .zip(get_vault_secret(&vault_secret_access_key))
+                    .zip(get_vault_secret(vault_secret_access_key))
             }
             None => {
                 // if using credentials directly specified
@@ -227,33 +196,18 @@ impl ForeignDataWrapper<S3FdwError> for S3Fdw {
         _sorts: &[Sort],
         _limit: &Option<Limit>,
         options: &HashMap<String, String>,
-    ) -> Result<(), S3FdwError> {
+    ) -> S3FdwResult<()> {
         // extract s3 bucket and object path from uri option
         let (bucket, object) = {
-            let uri = require_option("uri", options)?;
-            match uri.parse::<Uri>() {
-                Ok(uri) => {
-                    if uri.scheme_str() != Option::Some("s3")
-                        || uri.host().is_none()
-                        || uri.path().is_empty()
-                    {
-                        report_error(
-                            PgSqlErrorCode::ERRCODE_FDW_ERROR,
-                            &format!("invalid s3 uri: {}", uri),
-                        );
-                        return Ok(());
-                    }
-                    // exclude 1st "/" char in the path as s3 object path doesn't like it
-                    (uri.host().unwrap().to_owned(), uri.path()[1..].to_string())
-                }
-                Err(err) => {
-                    report_error(
-                        PgSqlErrorCode::ERRCODE_FDW_ERROR,
-                        &format!("parse s3 uri failed: {}", err),
-                    );
-                    return Ok(());
-                }
+            let uri = require_option("uri", options)?.parse::<Uri>()?;
+            if uri.scheme_str() != Option::Some("s3")
+                || uri.host().is_none()
+                || uri.path().is_empty()
+            {
+                return Err(S3FdwError::InvalidS3Uri(uri.to_string()));
            }
+            // exclude 1st "/" char in the path as s3 object path doesn't like it
+            (uri.host().unwrap().to_owned(), uri.path()[1..].to_string())
         };
 
         let has_header: bool = options.get("has_header") == Some(&"true".to_string());
@@ -268,31 +222,14 @@ impl ForeignDataWrapper<S3FdwError> for S3Fdw {
             "csv" => self.parser = Parser::Csv(csv::Reader::from_reader(Cursor::new(vec![0]))),
             "jsonl" => self.parser = Parser::JsonLine(VecDeque::new()),
             "parquet" => self.parser = Parser::Parquet(S3Parquet::default()),
-            _ => {
-                report_error(
-                    PgSqlErrorCode::ERRCODE_FDW_ERROR,
-                    &format!(
-                        "invalid format option: {}, it can only be 'csv', 'jsonl' or 'parquet'",
-                        format
-                    ),
-                );
-                return Ok(());
-            }
+            _ => return Err(S3FdwError::InvalidFormatOption(format.to_string())),
         }
 
-        let stream = match self
+        let stream = self
             .rt
-            .block_on(client.get_object().bucket(&bucket).key(&object).send())
-        {
-            Ok(resp) => resp.body.into_async_read(),
-            Err(err) => {
-                report_error(
-                    PgSqlErrorCode::ERRCODE_FDW_ERROR,
-                    &format!("request s3 failed: {}", err),
-                );
-                return Ok(());
-            }
-        };
+            .block_on(client.get_object().bucket(&bucket).key(&object).send())?
+            .body
+            .into_async_read();
 
         let mut boxed_stream: Pin> =
             if let Some(compress) = options.get("compress") {
@@ -302,13 +239,7 @@ impl ForeignDataWrapper<S3FdwError> for S3Fdw {
                     "gzip" => Box::pin(GzipDecoder::new(buf_rdr)),
                     "xz" => Box::pin(XzDecoder::new(buf_rdr)),
                     "zlib" => Box::pin(ZlibDecoder::new(buf_rdr)),
-                    _ => {
-                        report_error(
-                            PgSqlErrorCode::ERRCODE_FDW_ERROR,
-                            &format!("invalid compression option: {}", compress),
-                        );
-                        return Ok(());
-                    }
+                    _ => return Err(S3FdwError::InvalidCompressOption(compress.to_string())),
                 }
             } else {
                 Box::pin(stream)
             };
@@ -323,7 +254,7 @@ impl ForeignDataWrapper<S3FdwError> for S3Fdw {
                 self.rt
                     .block_on(boxed_stream.read_to_end(&mut buf))
                     .expect("read compressed parquet file failed");
-                self.rt.block_on(s3parquet.open_local_stream(buf));
+                self.rt.block_on(s3parquet.open_local_stream(buf))?;
             } else {
                 // open async read stream
                 self.rt.block_on(s3parquet.open_async_stream(
@@ -331,7 +262,7 @@ impl ForeignDataWrapper<S3FdwError> for S3Fdw {
                     &bucket,
                     &object,
                     &self.tgt_cols,
-                ));
+                ))?;
             }
             return Ok(());
         }
@@ -342,13 +273,7 @@ impl ForeignDataWrapper<S3FdwError> for S3Fdw {
         if let Parser::Csv(_) = self.parser {
             if has_header {
                 let mut header = String::new();
-                if let Err(err) = self.rt.block_on(rdr.read_line(&mut header)) {
-                    report_error(
-                        PgSqlErrorCode::ERRCODE_FDW_ERROR,
-                        &format!("fetch csv file failed: {}", err),
-                    );
-                    return Ok(());
-                }
+                self.rt.block_on(rdr.read_line(&mut header))?;
             }
         }
 
@@ -358,13 +283,13 @@ impl ForeignDataWrapper<S3FdwError> for S3Fdw {
         Ok(())
     }
 
-    fn iter_scan(&mut self, row: &mut Row) -> Result<Option<()>, S3FdwError> {
+    fn iter_scan(&mut self, row: &mut Row) -> S3FdwResult<Option<()>> {
         // read parquet record
         if let Parser::Parquet(ref mut s3parquet) = &mut self.parser {
-            if self.rt.block_on(s3parquet.refill()).is_none() {
+            if self.rt.block_on(s3parquet.refill())?.is_none() {
                 return Ok(None);
             }
-            let ret = s3parquet.read_into_row(row, &self.tgt_cols);
+            let ret = s3parquet.read_into_row(row, &self.tgt_cols)?;
             if ret.is_some() {
                 self.rows_out += 1;
             } else {
@@ -375,7 +300,7 @@ impl ForeignDataWrapper<S3FdwError> for S3Fdw {
 
         // read csv or jsonl record
         loop {
-            if self.refill().is_none() {
+            if self.refill()?.is_none() {
                 break;
             }
 
@@ -383,28 +308,17 @@ impl ForeignDataWrapper<S3FdwError> for S3Fdw {
             match &mut self.parser {
                 Parser::Csv(rdr) => {
                     let mut record = csv::StringRecord::new();
-                    match rdr.read_record(&mut record) {
-                        Ok(result) => {
-                            if result {
-                                for col in &self.tgt_cols {
-                                    let cell =
-                                        record.get(col.num - 1).map(|s| Cell::String(s.to_owned()));
-                                    row.push(&col.name, cell);
-                                }
-                                self.rows_out += 1;
-                                return Ok(Some(()));
-                            } else {
-                                // no more records left in the local buffer, refill from remote
-                                self.buf.clear();
-                            }
-                        }
-                        Err(err) => {
-                            report_error(
-                                PgSqlErrorCode::ERRCODE_FDW_ERROR,
-                                &format!("read csv record failed: {}", err),
-                            );
-                            break;
+                    let result = rdr.read_record(&mut record)?;
+                    if result {
+                        for col in &self.tgt_cols {
+                            let cell = record.get(col.num - 1).map(|s| Cell::String(s.to_owned()));
+                            row.push(&col.name, cell);
                         }
+                        self.rows_out += 1;
+                        return Ok(Some(()));
+                    } else {
+                        // no more records left in the local buffer, refill from remote
+                        self.buf.clear();
                     }
                 }
                 Parser::JsonLine(records) => {
@@ -452,17 +366,14 @@ impl ForeignDataWrapper<S3FdwError> for S3Fdw {
         Ok(None)
     }
 
-    fn end_scan(&mut self) -> Result<(), S3FdwError> {
+    fn end_scan(&mut self) -> S3FdwResult<()> {
         // release local resources
         self.rdr.take();
         self.parser = Parser::JsonLine(VecDeque::new());
         Ok(())
     }
 
-    fn validator(
-        options: Vec<Option<String>>,
-        catalog: Option<pg_sys::Oid>,
-    ) -> Result<(), S3FdwError> {
+    fn validator(options: Vec<Option<String>>, catalog: Option<pg_sys::Oid>) -> S3FdwResult<()> {
         if let Some(oid) = catalog {
             if oid == FOREIGN_TABLE_RELATION_ID {
                 check_options_contain(&options, "uri")?;