Skip to content

Commit

Permalink
feat: check constraints, removed physical test table, rewrote tests to not need a physical table, addressed most (but not all) PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
scarman-db committed Dec 8, 2023
1 parent 811a6d4 commit e906f3c
Show file tree
Hide file tree
Showing 71 changed files with 134 additions and 100 deletions.
32 changes: 19 additions & 13 deletions crates/deltalake-core/src/delta_datafusion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ use arrow::datatypes::{DataType as ArrowDataType, Schema as ArrowSchema, SchemaR
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;
use arrow_array::types::UInt16Type;
use arrow_array::{DictionaryArray, StringArray};
use arrow_array::{Array, DictionaryArray, StringArray};
use arrow_cast::display::array_value_to_string;

use arrow_schema::Field;
use async_trait::async_trait;
use chrono::{NaiveDateTime, TimeZone, Utc};
Expand Down Expand Up @@ -65,6 +67,7 @@ use datafusion_physical_expr::execution_props::ExecutionProps;
use datafusion_physical_expr::{create_physical_expr, PhysicalExpr};
use datafusion_proto::logical_plan::LogicalExtensionCodec;
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
use itertools::Itertools;
use log::error;
use object_store::ObjectMeta;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -1024,7 +1027,7 @@ pub struct DeltaDataChecker {

impl DeltaDataChecker {
/// Create a new DeltaDataChecker with a specified set of invariants
pub fn with_invariants(invariants: Vec<Invariant>) -> Self {
pub fn new_with_invariants(invariants: Vec<Invariant>) -> Self {
Self {
invariants,
constraints: vec![],
Expand All @@ -1033,7 +1036,7 @@ impl DeltaDataChecker {
}

/// Create a new DeltaDataChecker with a specified set of constraints
pub fn with_constraints(constraints: Vec<Constraint>) -> Self {
pub fn new_with_constraints(constraints: Vec<Constraint>) -> Self {
Self {
constraints,
invariants: vec![],
Expand Down Expand Up @@ -1075,11 +1078,8 @@ impl DeltaDataChecker {
if checks.is_empty() {
return Ok(());
}

// if !self.ctx.table_exist("data")? {
let table = MemTable::try_new(record_batch.schema(), vec![vec![record_batch.clone()]])?;
self.ctx.register_table("data", Arc::new(table))?;
// }

let mut violations: Vec<String> = Vec::new();

Expand All @@ -1098,9 +1098,15 @@ impl DeltaDataChecker {

let dfs: Vec<RecordBatch> = self.ctx.sql(&sql).await?.collect().await?;
if !dfs.is_empty() && dfs[0].num_rows() > 0 {
let value = format!("{:?}", dfs[0].column(0));
let values: String = dfs[0]
.columns()
.iter()
.map(|c| array_value_to_string(c, 0).unwrap_or(String::from("null")))
.join(", ");

let value = format!("{}", values);
let msg = format!(
"Check or Invariant ({}) violated by value {}",
"Check or Invariant ({}) violated by value in row: [{}]",
check.get_expression(),
value
);
Expand Down Expand Up @@ -1677,7 +1683,7 @@ mod tests {
.unwrap();
// Empty invariants is okay
let invariants: Vec<Invariant> = vec![];
assert!(DeltaDataChecker::with_invariants(invariants)
assert!(DeltaDataChecker::new_with_invariants(invariants)
.check_batch(&batch)
.await
.is_ok());
Expand All @@ -1687,7 +1693,7 @@ mod tests {
Invariant::new("a", "a is not null"),
Invariant::new("b", "b < 1000"),
];
assert!(DeltaDataChecker::with_invariants(invariants)
assert!(DeltaDataChecker::new_with_invariants(invariants)
.check_batch(&batch)
.await
.is_ok());
Expand All @@ -1697,7 +1703,7 @@ mod tests {
Invariant::new("a", "a is null"),
Invariant::new("b", "b < 100"),
];
let result = DeltaDataChecker::with_invariants(invariants)
let result = DeltaDataChecker::new_with_invariants(invariants)
.check_batch(&batch)
.await;
assert!(result.is_err());
Expand All @@ -1708,7 +1714,7 @@ mod tests {

// Irrelevant invariants return a different error
let invariants = vec![Invariant::new("c", "c > 2000")];
let result = DeltaDataChecker::with_invariants(invariants)
let result = DeltaDataChecker::new_with_invariants(invariants)
.check_batch(&batch)
.await;
assert!(result.is_err());
Expand All @@ -1724,7 +1730,7 @@ mod tests {
let batch = RecordBatch::try_new(schema, vec![inner]).unwrap();

let invariants = vec![Invariant::new("x.b", "x.b < 1000")];
let result = DeltaDataChecker::with_invariants(invariants)
let result = DeltaDataChecker::new_with_invariants(invariants)
.check_batch(&batch)
.await;
assert!(result.is_err());
Expand Down
173 changes: 114 additions & 59 deletions crates/deltalake-core/src/operations/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ use datafusion::execution::{SendableRecordBatchStream, TaskContext};
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionContext;
use futures::future::BoxFuture;
use futures::StreamExt;
use serde_json::json;

use crate::delta_datafusion::{find_files, register_store, DeltaDataChecker, DeltaScanBuilder};
use crate::delta_datafusion::{register_store, DeltaDataChecker, DeltaScanBuilder};
use crate::kernel::{Action, CommitInfo, IsolationLevel, Metadata, Protocol};
use crate::logstore::LogStoreRef;
use crate::operations::collect_sendable_stream;
use crate::operations::datafusion_utils::Expression;
use crate::operations::transaction::commit;
use crate::protocol::DeltaOperation;
Expand Down Expand Up @@ -68,21 +68,21 @@ impl std::future::IntoFuture for ConstraintBuilder {
type IntoFuture = BoxFuture<'static, Self::Output>;

fn into_future(self) -> Self::IntoFuture {
let this = self;
let mut this = self;

Box::pin(async move {
if this.name.is_none() {
return Err(DeltaTableError::Generic("No name provided".to_string()));
} else if this.expr.is_none() {
return Err(DeltaTableError::Generic(
"No expression provided".to_string(),
));
}
let name = this.name.unwrap();
let name = match this.name {
Some(v) => v,
None => return Err(DeltaTableError::Generic("No name provided".to_string())),
};
let expr = match this.expr {
Some(Expression::String(s)) => s,
Some(Expression::DataFusion(e)) => e.to_string(),
None => unreachable!(),
None => {
return Err(DeltaTableError::Generic(
"No expression provided".to_string(),
))
}
};

let mut metadata = this
Expand All @@ -106,22 +106,36 @@ impl std::future::IntoFuture for ConstraintBuilder {
});

// Build a checker here holding only the one-time constraint to validate.
let checker = DeltaDataChecker::with_constraints(vec![Constraint::new("*", &expr)]);

let files_to_check =
find_files(&this.snapshot, this.log_store.clone(), &state, None).await?;
let checker = DeltaDataChecker::new_with_constraints(vec![Constraint::new("*", &expr)]);
let scan = DeltaScanBuilder::new(&this.snapshot, this.log_store.clone(), &state)
.with_files(&files_to_check.candidates)
.build()
.await?;

let task_ctx = Arc::new(TaskContext::from(&state));
let record_stream: SendableRecordBatchStream = scan.execute(0, task_ctx)?;
let records = collect_sendable_stream(record_stream).await?;

for batch in records {
checker.check_batch(&batch).await?;
let plan: Arc<dyn ExecutionPlan> = Arc::new(scan);
let mut tasks = vec![];
for p in 0..plan.output_partitioning().partition_count() {
let inner_plan = plan.clone();
let inner_checker = checker.clone();
let task_ctx = Arc::new(TaskContext::from(&state));
let mut record_stream: SendableRecordBatchStream =
inner_plan.execute(p, task_ctx)?;
let handle: tokio::task::JoinHandle<DeltaResult<()>> =
tokio::task::spawn(async move {
while let Some(maybe_batch) = record_stream.next().await {
let batch = maybe_batch?;
inner_checker.check_batch(&batch).await?;
}
Ok(())
});
tasks.push(handle);
}
futures::future::join_all(tasks)
.await
.into_iter()
.collect::<Result<Vec<_>, _>>()
.map_err(|err| DeltaTableError::Generic(err.to_string()))?
.into_iter()
.collect::<Result<Vec<_>, _>>()?;

// We have validated that the table passes its constraints; now add the constraint to
// the table.
Expand All @@ -137,8 +151,8 @@ impl std::future::IntoFuture for ConstraintBuilder {
} else {
1
},
min_writer_version: if old_protocol.min_reader_version > 3 {
old_protocol.min_reader_version
min_writer_version: if old_protocol.min_writer_version > 3 {
old_protocol.min_writer_version
} else {
3
},
Expand All @@ -150,9 +164,15 @@ impl std::future::IntoFuture for ConstraintBuilder {
("name".to_string(), json!(&name)),
("expr".to_string(), json!(&expr)),
]);

let operations = DeltaOperation::AddConstraint {
name: name.clone(),
expr: expr.clone(),
};

let commit_info = CommitInfo {
timestamp: Some(Utc::now().timestamp_millis()),
operation: Some("ADD CONSTRAINT".to_string()),
operation: Some(operations.name().to_string()),
operation_parameters: Some(operational_parameters),
read_version: Some(this.snapshot.version()),
isolation_level: Some(IsolationLevel::Serializable),
Expand All @@ -166,11 +186,6 @@ impl std::future::IntoFuture for ConstraintBuilder {
Action::Protocol(protocol),
];

let operations = DeltaOperation::AddConstraint {
name: name.clone(),
expr: expr.clone(),
};

let _version = commit(
this.log_store.as_ref(),
&actions,
Expand All @@ -180,6 +195,11 @@ impl std::future::IntoFuture for ConstraintBuilder {
)
.await?;

this.snapshot.merge(
DeltaTableState::from_actions(actions, _version)?,
true,
true,
);
Ok(DeltaTable::new_with_state(this.log_store, this.snapshot))
})
}
Expand All @@ -190,71 +210,106 @@ impl std::future::IntoFuture for ConstraintBuilder {
mod tests {
use std::sync::Arc;

use arrow_array::{Array, Int64Array, RecordBatch};
use arrow_array::{Array, Int32Array, RecordBatch, StringArray};

use crate::kernel::StructType;
use crate::DeltaResult;
use crate::writer::test_utils::{create_bare_table, get_arrow_schema, get_record_batch};
use crate::{DeltaOps, DeltaResult};

#[cfg(feature = "datafusion")]
#[tokio::test]
async fn add_constraint_with_invalid_data() -> DeltaResult<()> {
let table = crate::DeltaOps::try_from_uri("./tests/data/check-constraints").await?;
let batch = get_record_batch(None, false);
let write = DeltaOps(create_bare_table())
.write(vec![batch.clone()])
.await?;
let table = DeltaOps(write);

let constraint = table
.add_constraint()
.with_constraint("id3", "id < 60")
.with_constraint("id", "value > 5")
.await;
dbg!(&constraint);
assert!(constraint.is_err());
Ok(())
}

#[cfg(feature = "datafusion")]
#[tokio::test]
async fn add_valid_constraint() -> DeltaResult<()> {
let table = crate::DeltaOps::try_from_uri("./tests/data/check-constraints").await?;
let batch = get_record_batch(None, false);
let write = DeltaOps(create_bare_table())
.write(vec![batch.clone()])
.await?;
let table = DeltaOps(write);

let constraint = table
.add_constraint()
.with_constraint("id2", "id < 1000")
.with_constraint("id", "value < 1000")
.await;
dbg!(&constraint);
assert!(constraint.is_ok());
let version = constraint?.version();
assert_eq!(version, 2);
assert_eq!(version, 1);
Ok(())
}

#[cfg(feature = "datafusion")]
#[tokio::test]
async fn add_conflicting_named_constraint() -> DeltaResult<()> {
let table = crate::DeltaOps::try_from_uri("./tests/data/check-constraints").await?;
let constraint = table
let batch = get_record_batch(None, false);
let write = DeltaOps(create_bare_table())
.write(vec![batch.clone()])
.await?;
let table = DeltaOps(write);

let new_table = table
.add_constraint()
.with_constraint("id", "id < 60")
.await;
.with_constraint("id", "value < 60")
.await?;

assert!(constraint.is_err());
let new_table = DeltaOps(new_table);
let second_constraint = new_table
.add_constraint()
.with_constraint("id", "value < 10")
.await;
dbg!(&second_constraint);
assert!(second_constraint.is_err());
Ok(())
}

#[cfg(feature = "datafusion")]
#[tokio::test]
async fn write_data_that_violates_constraint() -> DeltaResult<()> {
let table = crate::DeltaOps::try_from_uri("./tests/data/check-constraints").await?;
let metadata = table.0.metadata()?;
let arrow_schema = <arrow::datatypes::Schema as TryFrom<&StructType>>::try_from(
&metadata.schema()?.clone(),
)?;
let invalid_values: Vec<Arc<dyn Array>> = vec![Arc::new(Int64Array::from(vec![-10]))];
let batch = RecordBatch::try_new(Arc::new(arrow_schema), invalid_values)?;
let err = table.write(vec![batch]).await;
let batch = get_record_batch(None, false);
let write = DeltaOps(create_bare_table())
.write(vec![batch.clone()])
.await?;

let table = DeltaOps(write)
.add_constraint()
.with_constraint("id", "value > 0")
.await?;
let table = DeltaOps(table);
let invalid_values: Vec<Arc<dyn Array>> = vec![
Arc::new(StringArray::from(vec!["A"])),
Arc::new(Int32Array::from(vec![-10])),
Arc::new(StringArray::from(vec!["2021-02-02"])),
];
let batch = RecordBatch::try_new(get_arrow_schema(&None), invalid_values)?;
let err = table.write(vec![batch]).await;
dbg!(&err);
assert!(err.is_err());
Ok(())
}

#[tokio::test]
async fn write_data_that_does_not_violate_constraint() -> DeltaResult<()> {
let table = crate::DeltaOps::try_from_uri("./tests/data/check-constraints").await?;
let metadata = table.0.metadata()?;
let arrow_schema = <arrow::datatypes::Schema as TryFrom<&StructType>>::try_from(
&metadata.schema()?.clone(),
)?;
let invalid_values: Vec<Arc<dyn Array>> = vec![Arc::new(Int64Array::from(vec![160]))];
let batch = RecordBatch::try_new(Arc::new(arrow_schema), invalid_values)?;
let batch = get_record_batch(None, false);
let write = DeltaOps(create_bare_table())
.write(vec![batch.clone()])
.await?;
let table = DeltaOps(write);

let err = table.write(vec![batch]).await;

assert!(err.is_ok());
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading

0 comments on commit e906f3c

Please sign in to comment.