Skip to content

Commit

Permalink
improve array params handling in postgres (#4605)
Browse files Browse the repository at this point in the history
  • Loading branch information
HugoCasa authored Oct 30, 2024
1 parent f2db73b commit 1ff221f
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 102 deletions.
255 changes: 155 additions & 100 deletions backend/windmill-worker/src/pg_executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,8 @@ use serde_json::value::RawValue;
use serde_json::Map;
use serde_json::Value;
use tokio::sync::Mutex;
use tokio_postgres::types::IsNull;
use tokio_postgres::Client;
use tokio_postgres::{
types::{to_sql_checked, ToSql},
NoTls, Row,
};
use tokio_postgres::{types::ToSql, NoTls, Row};
use tokio_postgres::{
types::{FromSql, Type},
Column,
Expand All @@ -41,7 +37,7 @@ use windmill_queue::CanceledBy;
use crate::common::{build_args_values, sizeof_val, OccupancyMetrics};
use crate::handle_child::run_future_with_polling_update_job_poller;
use crate::{AuthedClientBackgroundTask, MAX_RESULT_SIZE};
use bytes::{Buf, BytesMut};
use bytes::Buf;
use lazy_static::lazy_static;
use urlencoding::encode;

Expand Down Expand Up @@ -98,6 +94,11 @@ fn do_postgresql_inner<'a>(

let mut res: Vec<serde_json::Value> = vec![];

let query_params = query_params
.iter()
.map(|p| &**p as &(dyn ToSql + Sync))
.collect_vec();

if skip_collect {
client
.execute_raw(&query, query_params)
Expand Down Expand Up @@ -414,133 +415,176 @@ pub async fn do_postgresql(
return Ok(raw_result);
}

#[derive(Debug)]
enum PgType {
String(String),
Bool(bool),
I8(i8),
I16(i16),
I32(i32),
I64(i64),
U32(u32),
F32(f32),
F64(f64),
Uuid(Uuid),
Decimal(Decimal),
Date(chrono::NaiveDate),
Time(chrono::NaiveTime),
Timestamp(chrono::NaiveDateTime),
None(Option<bool>),
Array(Vec<PgType>),
Json(serde_json::Value),
Bytea(Vec<u8>),
fn map_as_single_type<T>(
vec: &Vec<Value>,
f: impl Fn(&Value) -> Option<T>,
) -> anyhow::Result<Vec<Option<T>>> {
vec.into_iter()
.map(|v| {
// allow nulls in arrays
if matches!(v, Value::Null) {
Some(None)
} else {
f(v).map(Some)
}
})
.collect::<Option<Vec<Option<T>>>>()
.ok_or_else(|| anyhow::anyhow!("Mixed types in array"))
}

impl ToSql for PgType {
fn to_sql(
&self,
ty: &Type,
out: &mut BytesMut,
) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>> {
match *self {
PgType::String(ref val) => val.to_sql(ty, out),
PgType::Bool(ref val) => val.to_sql(ty, out),
PgType::I8(ref val) => val.to_sql(ty, out),
PgType::I16(ref val) => val.to_sql(ty, out),
PgType::I32(ref val) => val.to_sql(ty, out),
PgType::I64(ref val) => val.to_sql(ty, out),
PgType::U32(ref val) => val.to_sql(ty, out),
PgType::F32(ref val) => val.to_sql(ty, out),
PgType::F64(ref val) => val.to_sql(ty, out),
PgType::Uuid(ref val) => val.to_sql(ty, out),
PgType::Decimal(ref val) => val.to_sql(ty, out),
PgType::Date(ref val) => val.to_sql(ty, out),
PgType::Time(ref val) => val.to_sql(ty, out),
PgType::Timestamp(ref val) => val.to_sql(ty, out),
PgType::None(ref val) => val.to_sql(ty, out),
PgType::Array(ref val) => val.to_sql(ty, out),
PgType::Json(ref val) => val.to_sql(ty, out),
PgType::Bytea(ref val) => val.to_sql(ty, out),
fn convert_vec_val(
vec: &Vec<Value>,
arg_t: &String,
) -> windmill_common::error::Result<Box<dyn ToSql + Sync + Send>> {
match arg_t.as_str() {
"bool" | "boolean" => Ok(Box::new(map_as_single_type(vec, |v| v.as_bool())?)),
"char" | "character" => Ok(Box::new(map_as_single_type(vec, |v| {
v.as_i64().map(|x| x as i8)
})?)),
"smallint" | "smallserial" | "int2" | "serial2" => {
Ok(Box::new(map_as_single_type(vec, |v| {
v.as_i64().map(|x| x as i16)
})?))
}
"int" | "integer" | "int4" | "serial" => Ok(Box::new(map_as_single_type(vec, |v| {
v.as_i64().map(|x| x as i32)
})?)),
"numeric" | "decimal" => Ok(Box::new(map_as_single_type(vec, |v| {
if v.is_i64() {
Decimal::from_i64(v.as_i64().unwrap())
} else if v.is_f64() {
Decimal::from_f64(v.as_f64().unwrap())
} else {
None
}
})?)),
"oid" => Ok(Box::new(map_as_single_type(vec, |v| {
v.as_u64().map(|x| x as u32)
})?)),
"bigint" | "bigserial" | "int8" | "serial8" => {
Ok(Box::new(map_as_single_type(vec, |v| {
v.as_u64().map(|x| x as i64)
})?))
}
"real" | "float4" => Ok(Box::new(map_as_single_type(vec, |v| {
v.as_f64().map(|x| x as f32)
})?)),
"double" | "float8" => Ok(Box::new(map_as_single_type(vec, |v| v.as_f64())?)),
"uuid" => Ok(Box::new(map_as_single_type(vec, |v| {
v.as_str().map(|x| Uuid::parse_str(x).ok()).flatten()
})?)),
"date" => Ok(Box::new(map_as_single_type(vec, |v| {
v.as_str().map(|x| {
chrono::NaiveDate::parse_from_str(x, "%Y-%m-%dT%H:%M:%S.%3fZ").unwrap_or_default()
})
})?)),
"time" | "timetz" => Ok(Box::new(map_as_single_type(vec, |v| {
v.as_str().map(|x| {
chrono::NaiveTime::parse_from_str(x, "%Y-%m-%dT%H:%M:%S.%3fZ").unwrap_or_default()
})
})?)),
"timestamp" | "timestamptz" => Ok(Box::new(map_as_single_type(vec, |v| {
v.as_str().map(|x| {
chrono::NaiveDateTime::parse_from_str(x, "%Y-%m-%dT%H:%M:%S.%3fZ")
.unwrap_or_default()
})
})?)),
"jsonb" | "json" => Ok(Box::new(vec.clone().into_iter().map(Some).collect_vec())),
"bytea" => Ok(Box::new(map_as_single_type(vec, |v| {
v.as_str().map(|x| {
engine::general_purpose::STANDARD
.decode(x)
.unwrap_or(vec![])
})
})?)),
"text" | "varchar" => Ok(Box::new(map_as_single_type(vec, |v| {
v.as_str().map(|x| x.to_string())
})?)),
_ => Err(anyhow::anyhow!("Unsupported JSON array type"))?,
}

fn accepts(_: &Type) -> bool {
true
}

to_sql_checked!();
}

fn convert_val(value: &Value, arg_t: &String, typ: &Typ) -> windmill_common::error::Result<PgType> {
fn convert_val(
value: &Value,
arg_t: &String,
typ: &Typ,
) -> windmill_common::error::Result<Box<dyn ToSql + Sync + Send>> {
match value {
Value::Array(vec) if arg_t.ends_with("[]") => {
let arg_t = arg_t.trim_end_matches("[]").to_string();
let mut result = vec![];
for val in vec {
result.push(convert_val(val, &arg_t, typ)?);
}
Ok(PgType::Array(result))
}
Value::Null => Ok(PgType::None(None::<bool>)),
Value::Bool(b) => Ok(PgType::Bool(b.clone())),
Value::Number(n) if matches!(typ, Typ::Str(_)) => Ok(PgType::String(n.to_string())),
Value::Number(n) if n.is_i64() && arg_t == "char" => {
Ok(PgType::I8(n.as_i64().unwrap() as i8))
convert_vec_val(vec, &arg_t)
}
Value::Number(n) if n.is_i64() && (arg_t == "smallint" || arg_t == "smallserial") => {
Ok(PgType::I16(n.as_i64().unwrap() as i16))
Value::Null => Ok(Box::new(None::<bool>)),
Value::Bool(b) => Ok(Box::new(b.clone())),
Value::Number(n) if matches!(typ, Typ::Str(_)) => Ok(Box::new(n.to_string())),
Value::Number(n) if arg_t == "char" && n.is_i64() => {
Ok(Box::new(n.as_i64().unwrap() as i8))
}
Value::Number(n)
if n.is_i64()
&& (arg_t == "int"
|| arg_t == "integer"
|| arg_t == "int4"
|| arg_t == "serial") =>
if (arg_t == "smallint"
|| arg_t == "smallserial"
|| arg_t == "int2"
|| arg_t == "serial2")
&& n.is_i64() =>
{
Ok(PgType::I32(n.as_i64().unwrap() as i32))
Ok(Box::new(n.as_i64().unwrap() as i16))
}
Value::Number(n) if n.is_i64() && (arg_t == "numeric" || arg_t == "decimal") => Ok(
PgType::Decimal(Decimal::from_i64(n.as_i64().unwrap()).unwrap()),
),
Value::Number(n) if n.is_i64() => Ok(PgType::I64(n.as_i64().unwrap())),
Value::Number(n) if n.is_u64() && arg_t == "oid" => {
Ok(PgType::U32(n.as_u64().unwrap() as u32))
Value::Number(n)
if (arg_t == "int" || arg_t == "integer" || arg_t == "int4" || arg_t == "serial")
&& n.is_i64() =>
{
Ok(Box::new(n.as_i64().unwrap() as i32))
}
Value::Number(n) if n.is_u64() && (arg_t == "bigint" || arg_t == "bigserial") => {
Ok(PgType::I64(n.as_u64().unwrap() as i64))
Value::Number(n) if (arg_t == "real" || arg_t == "float4") && n.as_f64().is_some() => {
Ok(Box::new(n.as_f64().unwrap() as f32))
}
Value::Number(n) if n.is_f64() && arg_t == "real" => {
Ok(PgType::F32(n.as_f64().unwrap() as f32))
Value::Number(n) if (arg_t == "double" || arg_t == "float8") && n.as_f64().is_some() => {
Ok(Box::new(n.as_f64().unwrap()))
}
Value::Number(n) if n.is_f64() && arg_t == "double" => Ok(PgType::F64(n.as_f64().unwrap())),
Value::Number(n) if n.is_f64() && (arg_t == "numeric" || arg_t == "decimal") => Ok(
PgType::Decimal(Decimal::from_f64(n.as_f64().unwrap()).unwrap()),
Value::Number(n) if (arg_t == "numeric" || arg_t == "decimal") && n.is_i64() => Ok(
Box::new(Decimal::from_i64(n.as_i64().unwrap()).unwrap_or_default()),
),
Value::Number(n) if (arg_t == "numeric" || arg_t == "decimal") && n.is_f64() => Ok(
Box::new(Decimal::from_f64(n.as_f64().unwrap()).unwrap_or_default()),
),
Value::Number(n) => Ok(PgType::F64(n.as_f64().unwrap())),
Value::String(s) if arg_t == "uuid" => Ok(PgType::Uuid(Uuid::parse_str(s)?)),
Value::Number(n) if arg_t == "oid" && n.is_u64() => {
Ok(Box::new(n.as_u64().unwrap() as u32))
}
Value::Number(n)
if (arg_t == "bigint"
|| arg_t == "bigserial"
|| arg_t == "int8"
|| arg_t == "serial8")
&& n.is_u64() =>
{
Ok(Box::new(n.as_u64().unwrap() as i64))
}
Value::Number(n) if n.is_i64() => Ok(Box::new(n.as_i64().unwrap())),
Value::Number(n) => Ok(Box::new(n.as_f64().unwrap())),
Value::String(s) if arg_t == "uuid" => Ok(Box::new(Uuid::parse_str(s)?)),
Value::String(s) if arg_t == "date" => {
let date =
chrono::NaiveDate::parse_from_str(s, "%Y-%m-%dT%H:%M:%S.%3fZ").unwrap_or_default();
Ok(PgType::Date(date))
Ok(Box::new(date))
}
Value::String(s) if arg_t == "time" || arg_t == "timetz" => {
let time =
chrono::NaiveTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S.%3fZ").unwrap_or_default();
Ok(PgType::Time(time))
Ok(Box::new(time))
}
Value::String(s) if arg_t == "timestamp" || arg_t == "timestamptz" => {
let datetime = chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S.%3fZ")
.unwrap_or_default();
Ok(PgType::Timestamp(datetime))
Ok(Box::new(datetime))
}
Value::String(s) if arg_t == "bytea" => {
let bytes = engine::general_purpose::STANDARD
.decode(s)
.unwrap_or(vec![]);
Ok(PgType::Bytea(bytes))
Ok(Box::new(bytes))
}
Value::Object(_) => Ok(PgType::Json(value.clone())),
Value::String(s) => Ok(PgType::String(s.clone())),
Value::Object(_) => Ok(Box::new(value.clone())),
Value::String(s) => Ok(Box::new(s.clone())),
_ => Err(Error::ExecutionErr(format!(
"Unsupported type in query: {:?} and signature {arg_t:?}",
value
Expand Down Expand Up @@ -624,6 +668,9 @@ pub fn pg_cell_to_json_value(
Type::TS_VECTOR => get_basic(row, column, column_i, |a: StringCollector| {
Ok(JSONValue::String(a.0))
})?,
Type::OID => get_basic(row, column, column_i, |a: u32| {
Ok(JSONValue::Number(serde_json::Number::from(a)))
})?,
// array types
Type::BOOL_ARRAY => get_array(row, column, column_i, |a: bool| Ok(JSONValue::Bool(a)))?,
Type::BIT_ARRAY => get_array(row, column, column_i, |a: bit_vec::BitVec| match a.len() {
Expand Down Expand Up @@ -655,6 +702,10 @@ pub fn pg_cell_to_json_value(
Type::FLOAT8_ARRAY => {
get_array(row, column, column_i, |a: f64| Ok(f64_to_json_number(a)?))?
}
Type::NUMERIC_ARRAY => get_array(row, column, column_i, |a: Decimal| {
Ok(serde_json::to_value(a)
.map_err(|_| anyhow::anyhow!("Cannot convert decimal to json"))?)
})?,
// these types require a custom StringCollector struct as an intermediary (see struct at bottom)
Type::TS_VECTOR_ARRAY => get_array(row, column, column_i, |a: StringCollector| {
Ok(JSONValue::String(a.0))
Expand Down Expand Up @@ -766,7 +817,7 @@ fn get_array<'a, T: FromSql<'a>>(
val_to_json_val: impl Fn(T) -> Result<JSONValue, Error>,
) -> Result<JSONValue, Error> {
let raw_val_array = row
.try_get::<_, Option<Vec<T>>>(column_i)
.try_get::<_, Option<Vec<Option<T>>>>(column_i)
.with_context(|| {
format!(
"conversion issue for array at column_name `{}`",
Expand All @@ -777,7 +828,11 @@ fn get_array<'a, T: FromSql<'a>>(
Some(val_array) => {
let mut result = vec![];
for val in val_array {
result.push(val_to_json_val(val)?);
result.push(
val.map(|v| val_to_json_val(v))
.transpose()?
.unwrap_or(Value::Null),
);
}
JSONValue::Array(result)
}
Expand Down
8 changes: 6 additions & 2 deletions frontend/src/lib/consts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ export const POSTGRES_TYPES = [
'BIGSERIAL[]',
'REAL',
'REAL[]',
'DOUBLE PRECISION',
'DOUBLE PRECISION[]',
'FLOAT8',
'FLOAT8[]',
'NUMERIC',
'NUMERIC[]',
'DECIMAL',
Expand All @@ -77,8 +77,12 @@ export const POSTGRES_TYPES = [
'DATE[]',
'TIME',
'TIME[]',
'TIMETZ',
'TIMETZ[]',
'TIMESTAMP',
'TIMESTAMP[]',
'TIMESTAMPTZ',
'TIMESTAMPTZ[]',
'JSON',
'JSON[]',
'JSONB',
Expand Down

0 comments on commit 1ff221f

Please sign in to comment.