diff --git a/crates/core/src/client/messages.rs b/crates/core/src/client/messages.rs index 488c15150e0..d59614f1c89 100644 --- a/crates/core/src/client/messages.rs +++ b/crates/core/src/client/messages.rs @@ -205,7 +205,7 @@ impl ServerMessage for OneOffQueryResponseMessage { .results .into_iter() .map(|table| OneOffTableJson { - table_name: table.head.table_name, + table_name: table.head.table_name.clone(), rows: table.data, }) .collect(), @@ -221,7 +221,7 @@ impl ServerMessage for OneOffQueryResponseMessage { .results .into_iter() .map(|table| OneOffTable { - table_name: table.head.table_name, + table_name: table.head.table_name.clone(), row: table .data .into_iter() diff --git a/crates/core/src/sql/compiler.rs b/crates/core/src/sql/compiler.rs index fbba23c0c85..61d7fa36a16 100644 --- a/crates/core/src/sql/compiler.rs +++ b/crates/core/src/sql/compiler.rs @@ -10,6 +10,7 @@ use spacetimedb_vm::dsl::{db_table, db_table_raw, query}; use spacetimedb_vm::expr::{ColumnOp, CrudExpr, DbType, Expr, QueryExpr, SourceExpr}; use spacetimedb_vm::operator::OpCmp; use std::collections::HashMap; +use std::sync::Arc; use super::ast::TableSchemaView; @@ -182,7 +183,7 @@ fn compile_columns(table: &TableSchema, columns: Vec) -> DbTable { } } DbTable::new( - Header::new(table.table_name.clone(), new, table.get_constraints()), + Arc::new(Header::new(table.table_name.clone(), new, table.get_constraints())), table.table_id, table.table_type, table.table_access, diff --git a/crates/core/src/sql/execute.rs b/crates/core/src/sql/execute.rs index b379e2ff037..b3c8a943303 100644 --- a/crates/core/src/sql/execute.rs +++ b/crates/core/src/sql/execute.rs @@ -133,7 +133,7 @@ pub(crate) mod tests { use spacetimedb_sats::relation::Header; use spacetimedb_sats::{product, AlgebraicType, ProductType}; use spacetimedb_vm::dsl::{mem_table, scalar}; - use spacetimedb_vm::eval::create_game_data; + use spacetimedb_vm::eval::test_data::create_game_data; use tempfile::TempDir; /// Short-cut for simplify test execution @@ -369,9 +369,27 @@ pub(crate) mod tests { let (db, _tmp_dir) = make_test_db()?; let mut tx = db.begin_mut_tx(IsolationLevel::Serializable); - create_table_with_rows(&db, &mut tx, "Inventory", data.inv.head.into(), &data.inv.data)?; - create_table_with_rows(&db, &mut tx, "Player", data.player.head.into(), &data.player.data)?; - create_table_with_rows(&db, &mut tx, "Location", data.location.head.into(), &data.location.data)?; + create_table_with_rows( + &db, + &mut tx, + "Inventory", + data.inv.head.to_product_type(), + &data.inv.data, + )?; + create_table_with_rows( + &db, + &mut tx, + "Player", + data.player.head.to_product_type(), + &data.player.data, + )?; + create_table_with_rows( + &db, + &mut tx, + "Location", + data.location.head.to_product_type(), + &data.location.data, + )?; db.commit_tx(&ExecutionContext::default(), tx)?; let result = &run_for_testing( diff --git a/crates/core/src/subscription/query.rs b/crates/core/src/subscription/query.rs index 2f8b0754720..14204aa0004 100644 --- a/crates/core/src/subscription/query.rs +++ b/crates/core/src/subscription/query.rs @@ -1,3 +1,4 @@ +use std::sync::Arc; use std::time::Instant; use crate::db::db_metrics::{DB_METRICS, MAX_QUERY_COMPILE_TIME}; @@ -32,7 +33,7 @@ pub const OP_TYPE_FIELD_NAME: &str = "__op_type"; /// Create a virtual table from a sequence of table updates. /// Add a special column __op_type to distinguish inserts and deletes. #[tracing::instrument(skip_all)] -pub fn to_mem_table_with_op_type(head: Header, table_access: StAccess, data: &DatabaseTableUpdate) -> MemTable { +pub fn to_mem_table_with_op_type(head: Arc
, table_access: StAccess, data: &DatabaseTableUpdate) -> MemTable { let mut t = MemTable::new(head, table_access, vec![]); if let Some(pos) = t.head.find_pos_by_name(OP_TYPE_FIELD_NAME) { @@ -42,11 +43,14 @@ pub fn to_mem_table_with_op_type(head: Header, table_access: StAccess, data: &Da new })); } else { - t.head.fields.push(Column::new( + // TODO(perf): Eliminate this `clone_for_error` call, as we're not in an error path. + let mut head = t.head.clone_for_error(); + head.fields.push(Column::new( FieldName::named(&t.head.table_name, OP_TYPE_FIELD_NAME), AlgebraicType::U8, t.head.fields.len().into(), )); + t.head = Arc::new(head); for row in &data.ops { let mut new = row.row.clone(); new.elements.push(row.op_type.into()); diff --git a/crates/core/src/subscription/subscription.rs b/crates/core/src/subscription/subscription.rs index 1ee95280535..71dfcbbb8b0 100644 --- a/crates/core/src/subscription/subscription.rs +++ b/crates/core/src/subscription/subscription.rs @@ -27,6 +27,7 @@ use anyhow::Context; use derive_more::{Deref, DerefMut, From, IntoIterator}; use std::collections::{btree_set, BTreeSet, HashMap, HashSet}; use std::ops::Deref; +use std::sync::Arc; use std::time::Instant; use crate::db::db_metrics::{DB_METRICS, MAX_QUERY_CPU_TIME}; @@ -681,7 +682,7 @@ impl<'a> IncrementalJoin<'a> { /// Replace an [IndexJoin]'s scan or fetch operation with a delta table. /// A delta table consists purely of updates or changes to the base table. fn with_delta_table(mut join: IndexJoin, index_side: bool, delta: DatabaseTableUpdate) -> IndexJoin { - fn to_mem_table(head: Header, table_access: StAccess, delta: DatabaseTableUpdate) -> MemTable { + fn to_mem_table(head: Arc
, table_access: StAccess, delta: DatabaseTableUpdate) -> MemTable { MemTable::new( head, table_access, diff --git a/crates/core/src/vm.rs b/crates/core/src/vm.rs index 22f1c50e532..bcaf73ab88a 100644 --- a/crates/core/src/vm.rs +++ b/crates/core/src/vm.rs @@ -19,6 +19,7 @@ use spacetimedb_vm::iterators::RelIter; use spacetimedb_vm::program::ProgramVm; use spacetimedb_vm::rel_ops::RelOps; use spacetimedb_vm::relation::{MemTable, RelValue, Table}; +use std::sync::Arc; pub enum TxMode<'a> { MutTx(&'a mut MutTx), @@ -155,7 +156,7 @@ fn join_inner<'a>( let header = if semi { col_lhs_header.clone() } else { - col_lhs_header.extend(&col_rhs_header) + Arc::new(col_lhs_header.extend(&col_rhs_header)) }; lhs.join_inner( @@ -221,7 +222,7 @@ pub struct IndexSemiJoin<'a, Rhs: RelOps<'a>> { // The field whose value will be used to probe the index. pub probe_field: FieldName, // The header for the index side of the join. - pub index_header: Header, + pub index_header: Arc
, // An optional predicate to evaluate over the matching rows of the index. pub index_select: Option, // The table id on which the index is defined. @@ -261,7 +262,7 @@ impl<'a, Rhs: RelOps<'a>> IndexSemiJoin<'a, Rhs> { } impl<'a, Rhs: RelOps<'a>> RelOps<'a> for IndexSemiJoin<'a, Rhs> { - fn head(&self) -> &Header { + fn head(&self) -> &Arc
{ if self.return_index_rows { &self.index_header } else { @@ -502,7 +503,7 @@ impl ProgramVm for DbProgram<'_, '_> { } impl<'a> RelOps<'a> for TableCursor<'a> { - fn head(&self) -> &Header { + fn head(&self) -> &Arc
{ &self.table.head } @@ -516,7 +517,7 @@ impl<'a> RelOps<'a> for TableCursor<'a> { } impl<'a, R: RangeBounds> RelOps<'a> for IndexCursor<'a, R> { - fn head(&self) -> &Header { + fn head(&self) -> &Arc
{ &self.table.head } @@ -533,7 +534,7 @@ impl<'a, I> RelOps<'a> for CatalogCursor where I: Iterator, { - fn head(&self) -> &Header { + fn head(&self) -> &Arc
{ &self.table.head } @@ -663,7 +664,7 @@ pub(crate) mod tests { let result = run_ast(p, q.into()); //The expected result - let input = mem_table(schema.head, vec![row]); + let input = mem_table(schema.head.clone_for_error(), vec![row]); assert_eq!(result, Code::Table(input), "{}", name); } diff --git a/crates/sats/src/db/def.rs b/crates/sats/src/db/def.rs index 935298d68c2..78c7fdb5b68 100644 --- a/crates/sats/src/db/def.rs +++ b/crates/sats/src/db/def.rs @@ -1,6 +1,7 @@ use derive_more::Display; use itertools::Itertools; use std::collections::{HashMap, HashSet}; +use std::sync::Arc; use crate::db::auth::{StAccess, StTableType}; use crate::db::error::{DefType, SchemaError}; @@ -938,7 +939,12 @@ impl From<&TableSchema> for ProductType { impl From<&TableSchema> for DbTable { fn from(value: &TableSchema) -> Self { - DbTable::new(value.into(), value.table_id, value.table_type, value.table_access) + DbTable::new( + Arc::new(value.into()), + value.table_id, + value.table_type, + value.table_access, + ) } } diff --git a/crates/sats/src/relation.rs b/crates/sats/src/relation.rs index df6cc1478c0..01e3ead9600 100644 --- a/crates/sats/src/relation.rs +++ b/crates/sats/src/relation.rs @@ -8,6 +8,7 @@ use spacetimedb_primitives::{ColId, ColList, ColListBuilder, Constraints, TableI use std::collections::hash_map::DefaultHasher; use std::fmt; use std::hash::{Hash, Hasher}; +use std::sync::Arc; pub fn calculate_hash(t: &T) -> u64 { let mut s = DefaultHasher::new(); @@ -15,7 +16,7 @@ pub fn calculate_hash(t: &T) -> u64 { s.finish() } -#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash)] pub struct TableField<'a> { pub table: Option<&'a str>, pub field: &'a str, @@ -34,7 +35,7 @@ pub fn extract_table_field(ident: &str) -> Result { } } -#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash)] pub enum FieldOnly<'a> { Name(&'a str), Pos(usize), @@ -53,6 +54,7 @@ impl fmt::Display for FieldOnly<'_> { } } +// TODO(perf): Remove `Clone` derivation. #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] pub enum FieldName { Name { table: String, field: String }, @@ -101,6 +103,7 @@ impl FieldName { } } +// TODO(perf): Remove `Clone` derivation. #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash, From)] pub enum FieldExpr { Name(FieldName), @@ -135,12 +138,13 @@ impl fmt::Display for FieldExpr { } } -#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] +#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash)] pub struct ColumnOnlyField<'a> { pub field: FieldOnly<'a>, pub algebraic_type: &'a AlgebraicType, } +// TODO(perf): Remove `Clone` derivation. #[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub struct Column { pub field: FieldName, @@ -165,12 +169,13 @@ impl Column { } } +// TODO(perf): Remove `Clone` impl. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] pub struct HeaderOnlyField<'a> { pub fields: Vec>, } -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Header { pub table_name: String, pub fields: Vec, @@ -186,6 +191,20 @@ impl Header { } } + /// Equivalent to what [`Clone::clone`] would do. + /// + /// `Header` intentionally does not implement `Clone`, + /// as we can't afford to clone it in normal execution paths. + /// However, we don't care about performance in error paths, + /// and we need to embed owned `Header`s in error objects to report useful messages. + pub fn clone_for_error(&self) -> Self { + Header { + table_name: self.table_name.clone(), + fields: self.fields.clone(), + constraints: self.constraints.clone(), + } + } + pub fn from_product_type(table_name: String, fields: ProductType) -> Self { let cols = fields .elements @@ -209,6 +228,14 @@ impl Header { Self::new(table_name, cols, Default::default()) } + pub fn to_product_type(&self) -> ProductType { + ProductType::from_iter( + self.fields.iter().map(|x| { + ProductTypeElement::new(x.algebraic_type.clone(), x.field.field_name().map(ToString::to_string)) + }), + ) + } + pub fn for_mem_table(fields: ProductType) -> Self { let table_name = format!("mem#{:x}", calculate_hash(&fields)); Self::from_product_type(table_name, fields) @@ -246,7 +273,7 @@ impl Header { pub fn column_pos_or_err<'a>(&'a self, col: &'a FieldName) -> Result { self.column_pos(col) - .ok_or_else(|| RelationError::FieldNotFound(self.clone(), col.clone())) + .ok_or_else(|| RelationError::FieldNotFound(self.clone_for_error(), col.clone())) } /// Finds the position of a field with `name`. @@ -419,7 +446,7 @@ impl RowCount { /// A [Relation] is anything that could be represented as a [Header] of `[ColumnName:ColumnType]` that /// generates rows/tuples of [AlgebraicValue] that exactly match that [Header]. pub trait Relation { - fn head(&self) -> &Header; + fn head(&self) -> &Arc
; /// Specify the size in rows of the [Relation]. /// /// Warning: It should at least be precise in the lower-bound estimate. @@ -429,14 +456,14 @@ pub trait Relation { /// A stored table from [RelationalDB] #[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] pub struct DbTable { - pub head: Header, + pub head: Arc
, pub table_id: TableId, pub table_type: StTableType, pub table_access: StAccess, } impl DbTable { - pub fn new(head: Header, table_id: TableId, table_type: StTableType, table_access: StAccess) -> Self { + pub fn new(head: Arc
, table_id: TableId, table_type: StTableType, table_access: StAccess) -> Self { Self { head, table_id, @@ -447,7 +474,7 @@ impl DbTable { } impl Relation for DbTable { - fn head(&self) -> &Header { + fn head(&self) -> &Arc
{ &self.head } @@ -489,20 +516,20 @@ mod tests { let head = head("t1", ("a", "b"), 0); let new = head.project(&[] as &[FieldName]).unwrap(); - let mut empty = head.clone(); + let mut empty = head.clone_for_error(); empty.fields.clear(); empty.constraints.clear(); assert_eq!(empty, new); - let all = head.clone(); + let all = head.clone_for_error(); let new = head .project(&[FieldName::named("t1", "a"), FieldName::named("t1", "b")]) .unwrap(); assert_eq!(all, new); - let mut first = head.clone(); + let mut first = head.clone_for_error(); first.fields.pop(); first.constraints = first.retain_constraints(&0.into()); @@ -510,7 +537,7 @@ mod tests { assert_eq!(first, new); - let mut second = head.clone(); + let mut second = head.clone_for_error(); second.fields.remove(0); second.constraints = second.retain_constraints(&1.into()); diff --git a/crates/vm/src/dsl.rs b/crates/vm/src/dsl.rs index 99f2d791904..df4ed1a0de5 100644 --- a/crates/vm/src/dsl.rs +++ b/crates/vm/src/dsl.rs @@ -7,6 +7,7 @@ use spacetimedb_sats::algebraic_value::AlgebraicValue; use spacetimedb_sats::db::auth::{StAccess, StTableType}; use spacetimedb_sats::product_value::ProductValue; use spacetimedb_sats::relation::{DbTable, Header}; +use std::sync::Arc; pub fn scalar>(of: T) -> AlgebraicValue { of.into() @@ -23,7 +24,7 @@ where I: IntoIterator, T: Into, { - MemTable::from_iter(head.into(), iter.into_iter().map(Into::into)) + MemTable::from_iter(Arc::new(head.into()), iter.into_iter().map(Into::into)) } pub fn db_table_raw>( @@ -32,7 +33,7 @@ pub fn db_table_raw>( table_type: StTableType, table_access: StAccess, ) -> DbTable { - DbTable::new(head.into(), table_id, table_type, table_access) + DbTable::new(Arc::new(head.into()), table_id, table_type, table_access) } /// Create a [DbTable] of type [StTableType::User] and derive `StAccess::for_name(name)`. diff --git a/crates/vm/src/eval.rs b/crates/vm/src/eval.rs index 37ea9425874..835cd00cc96 100644 --- a/crates/vm/src/eval.rs +++ b/crates/vm/src/eval.rs @@ -1,13 +1,13 @@ -use crate::dsl::mem_table; +use std::sync::Arc; + use crate::errors::ErrorVm; use crate::expr::{Code, CrudCode, CrudExpr, QueryCode, QueryExpr, SourceExpr}; use crate::expr::{Expr, Query}; use crate::iterators::RelIter; use crate::program::ProgramVm; use crate::rel_ops::RelOps; -use crate::relation::{MemTable, RelValue, Table}; +use crate::relation::{RelValue, Table}; use spacetimedb_sats::relation::{FieldExpr, Relation}; -use spacetimedb_sats::{product, AlgebraicType, ProductType}; fn compile_query(q: QueryExpr) -> QueryCode { match q.source { @@ -113,7 +113,7 @@ pub fn build_query<'a>(mut result: Box>, query: Vec) -> Resu let iter = lhs.join_inner( rhs, - col_lhs_header.extend(&col_rhs_header), + Arc::new(col_lhs_header.extend(&col_rhs_header)), move |row| Ok(row.get(&key_lhs, &key_lhs_header)?.into_owned().into()), move |row| Ok(row.get(&key_rhs, &key_rhs_header)?.into_owned().into()), move |l, r| { @@ -189,44 +189,49 @@ pub fn run_ast(p: &mut P, ast: Expr) -> Code { eval(p, code) } -// Used internally for testing SQL JOINS -#[doc(hidden)] -pub struct GameData { - pub location: MemTable, - pub inv: MemTable, - pub player: MemTable, -} -// Used internally for testing SQL JOINS +/// Used internally for testing SQL JOINS. #[doc(hidden)] -pub fn create_game_data() -> GameData { - let head = ProductType::from([("inventory_id", AlgebraicType::U64), ("name", AlgebraicType::String)]); - let row = product!(1u64, "health"); - let inv = mem_table(head, [row]); - - let head = ProductType::from([("entity_id", AlgebraicType::U64), ("inventory_id", AlgebraicType::U64)]); - let row1 = product!(100u64, 1u64); - let row2 = product!(200u64, 1u64); - let row3 = product!(300u64, 1u64); - let player = mem_table(head, [row1, row2, row3]); - - let head = ProductType::from([ - ("entity_id", AlgebraicType::U64), - ("x", AlgebraicType::F32), - ("z", AlgebraicType::F32), - ]); - let row1 = product!(100u64, 0.0f32, 32.0f32); - let row2 = product!(100u64, 1.0f32, 31.0f32); - let location = mem_table(head, [row1, row2]); - - GameData { location, inv, player } +pub mod test_data { + use crate::{dsl::mem_table, relation::MemTable}; + use spacetimedb_sats::{product, AlgebraicType, ProductType}; + + pub struct GameData { + pub location: MemTable, + pub inv: MemTable, + pub player: MemTable, + } + + pub fn create_game_data() -> GameData { + let head = ProductType::from([("inventory_id", AlgebraicType::U64), ("name", AlgebraicType::String)]); + let row = product!(1u64, "health"); + let inv = mem_table(head, [row]); + + let head = ProductType::from([("entity_id", AlgebraicType::U64), ("inventory_id", AlgebraicType::U64)]); + let row1 = product!(100u64, 1u64); + let row2 = product!(200u64, 1u64); + let row3 = product!(300u64, 1u64); + let player = mem_table(head, [row1, row2, row3]); + + let head = ProductType::from([ + ("entity_id", AlgebraicType::U64), + ("x", AlgebraicType::F32), + ("z", AlgebraicType::F32), + ]); + let row1 = product!(100u64, 0.0f32, 32.0f32); + let row2 = product!(100u64, 1.0f32, 31.0f32); + let location = mem_table(head, [row1, row2]); + + GameData { location, inv, player } + } } #[cfg(test)] -mod tests { +pub mod tests { #![allow(clippy::disallowed_macros)] + use super::test_data::*; use super::*; - use crate::dsl::{query, scalar}; + use crate::dsl::{mem_table, query, scalar}; use crate::program::Program; use crate::relation::MemTable; use spacetimedb_lib::identity::AuthCtx; @@ -234,6 +239,7 @@ mod tests { use spacetimedb_sats::db::auth::StAccess; use spacetimedb_sats::db::error::RelationError; use spacetimedb_sats::relation::FieldName; + use spacetimedb_sats::{product, AlgebraicType, ProductType}; fn run_query(p: &mut Program, ast: Expr) -> MemTable { match run_ast(p, ast) { @@ -286,7 +292,7 @@ mod tests { let result = run_ast(p, q.into()); assert_eq!( result, - Code::Halt(RelationError::FieldNotFound(head, field).into()), + Code::Halt(RelationError::FieldNotFound(head.clone_for_error(), field).into()), "Bad Project" ); } diff --git a/crates/vm/src/expr.rs b/crates/vm/src/expr.rs index a59e42ac4fa..ec3685c13cd 100644 --- a/crates/vm/src/expr.rs +++ b/crates/vm/src/expr.rs @@ -17,6 +17,7 @@ use std::collections::{HashMap, VecDeque}; use std::fmt; use std::hash::{Hash, Hasher}; use std::ops::Bound; +use std::sync::Arc; /// Trait for checking if the `caller` have access to `Self` pub trait AuthAccess { @@ -297,7 +298,7 @@ impl SourceExpr { } } - pub fn head(&self) -> &Header { + pub fn head(&self) -> &Arc
{ match self { SourceExpr::MemTable(x) => &x.head, SourceExpr::DbTable(x) => &x.head, @@ -313,7 +314,7 @@ impl SourceExpr { } impl Relation for SourceExpr { - fn head(&self) -> &Header { + fn head(&self) -> &Arc
{ match self { SourceExpr::MemTable(x) => x.head(), SourceExpr::DbTable(x) => x.head(), @@ -349,7 +350,7 @@ impl From for Table { impl From<&TableSchema> for SourceExpr { fn from(value: &TableSchema) -> Self { SourceExpr::DbTable(DbTable::new( - value.into(), + Arc::new(value.into()), value.table_id, value.table_type, value.table_access, @@ -1565,7 +1566,7 @@ impl AuthAccess for QueryCode { } impl Relation for QueryCode { - fn head(&self) -> &Header { + fn head(&self) -> &Arc
{ self.table.head() } @@ -1681,20 +1682,20 @@ mod tests { fn tables() -> [Table; 2] { [ Table::MemTable(MemTable { - head: Header { + head: Arc::new(Header { table_name: "foo".into(), fields: vec![], constraints: Default::default(), - }, + }), data: vec![], table_access: StAccess::Private, }), Table::DbTable(DbTable { - head: Header { + head: Arc::new(Header { table_name: "foo".into(), fields: vec![], constraints: vec![(ColId(42).into(), Constraints::indexed())], - }, + }), table_id: 42.into(), table_type: StTableType::User, table_access: StAccess::Private, @@ -1722,11 +1723,11 @@ mod tests { field: "bar".into(), }, index_side: Table::DbTable(DbTable { - head: Header { + head: Arc::new(Header { table_name: "bar".into(), fields: vec![], constraints: Default::default(), - }, + }), table_id: 42.into(), table_type: StTableType::User, table_access: StAccess::Public, @@ -1792,7 +1793,7 @@ mod tests { .collect(), ); MemTable { - head, + head: Arc::new(head), data, table_access, } diff --git a/crates/vm/src/iterators.rs b/crates/vm/src/iterators.rs index c3506ae0fb8..bb8edd14928 100644 --- a/crates/vm/src/iterators.rs +++ b/crates/vm/src/iterators.rs @@ -3,18 +3,19 @@ use crate::rel_ops::RelOps; use crate::relation::{MemTable, RelValue}; use core::mem; use spacetimedb_sats::relation::{Header, RowCount}; +use std::sync::Arc; /// Common wrapper for relational iterators that work like cursors. #[derive(Debug)] pub struct RelIter { - pub head: Header, + pub head: Arc
, pub row_count: RowCount, pub pos: usize, pub of: T, } impl RelIter { - pub fn new(head: Header, row_count: RowCount, of: T) -> Self { + pub fn new(head: Arc
, row_count: RowCount, of: T) -> Self { Self { head, row_count, @@ -25,7 +26,7 @@ impl RelIter { } impl<'a> RelOps<'a> for RelIter { - fn head(&self) -> &Header { + fn head(&self) -> &Arc
{ &self.head } diff --git a/crates/vm/src/rel_ops.rs b/crates/vm/src/rel_ops.rs index e18b5775d75..f287b3ebb4b 100644 --- a/crates/vm/src/rel_ops.rs +++ b/crates/vm/src/rel_ops.rs @@ -3,6 +3,7 @@ use crate::relation::RelValue; use spacetimedb_sats::product_value::ProductValue; use spacetimedb_sats::relation::{FieldExpr, Header, RowCount}; use std::collections::HashMap; +use std::sync::Arc; pub(crate) trait ResultExt { fn unpack_fold(self) -> Result; @@ -10,7 +11,7 @@ pub(crate) trait ResultExt { /// A trait for dealing with fallible iterators for the database. pub trait RelOps<'a> { - fn head(&self) -> &Header; + fn head(&self) -> &Arc
; fn row_count(&self) -> RowCount; /// Advances the `iterator` and returns the next [RelValue]. fn next(&mut self) -> Result>, ErrorVm>; @@ -67,7 +68,7 @@ pub trait RelOps<'a> { { let count = self.row_count(); let head = self.head().project(&cols)?; - Ok(Project::new(self, count, head, cols, extractor)) + Ok(Project::new(self, count, Arc::new(head), cols, extractor)) } /// Intersection between the left and the right, both (non-sorted) `iterators`. @@ -84,7 +85,7 @@ pub trait RelOps<'a> { fn join_inner( self, with: Rhs, - head: Header, + head: Arc
, key_lhs: KeyLhs, key_rhs: KeyRhs, predicate: Pred, @@ -120,7 +121,7 @@ pub trait RelOps<'a> { } impl<'a, I: RelOps<'a> + ?Sized> RelOps<'a> for Box { - fn head(&self) -> &Header { + fn head(&self) -> &Arc
{ (**self).head() } @@ -135,14 +136,14 @@ impl<'a, I: RelOps<'a> + ?Sized> RelOps<'a> for Box { #[derive(Clone, Debug)] pub struct Select { - pub(crate) head: Header, + pub(crate) head: Arc
, pub(crate) count: RowCount, pub(crate) iter: I, pub(crate) predicate: P, } impl Select { - pub fn new(iter: I, count: RowCount, head: Header, predicate: P) -> Select { + pub fn new(iter: I, count: RowCount, head: Arc
, predicate: P) -> Select { Select { iter, count, @@ -157,7 +158,7 @@ where I: RelOps<'a>, P: FnMut(&RelValue<'a>) -> Result, { - fn head(&self) -> &Header { + fn head(&self) -> &Arc
{ &self.head } @@ -178,7 +179,7 @@ where #[derive(Clone, Debug)] pub struct Project { - pub(crate) head: Header, + pub(crate) head: Arc
, pub(crate) count: RowCount, pub(crate) cols: Vec, pub(crate) iter: I, @@ -186,7 +187,7 @@ pub struct Project { } impl Project { - pub fn new(iter: I, count: RowCount, head: Header, cols: Vec, extractor: P) -> Project { + pub fn new(iter: I, count: RowCount, head: Arc
, cols: Vec, extractor: P) -> Project { Project { iter, count, @@ -202,7 +203,7 @@ where I: RelOps<'a>, P: FnMut(&[FieldExpr], RelValue<'a>) -> Result, ErrorVm>, { - fn head(&self) -> &Header { + fn head(&self) -> &Arc
{ &self.head } @@ -221,7 +222,7 @@ where #[derive(Clone, Debug)] pub struct JoinInner<'a, Lhs, Rhs, KeyLhs, KeyRhs, Pred, Proj> { - pub(crate) head: Header, + pub(crate) head: Arc
, pub(crate) count: RowCount, pub(crate) lhs: Lhs, pub(crate) rhs: Rhs, @@ -236,7 +237,7 @@ pub struct JoinInner<'a, Lhs, Rhs, KeyLhs, KeyRhs, Pred, Proj> { impl<'a, Lhs, Rhs, KeyLhs, KeyRhs, Pred, Proj> JoinInner<'a, Lhs, Rhs, KeyLhs, KeyRhs, Pred, Proj> { pub fn new( - head: Header, + head: Arc
, lhs: Lhs, rhs: Rhs, key_lhs: KeyLhs, @@ -270,7 +271,7 @@ where Pred: FnMut(&RelValue<'a>, &RelValue<'a>) -> Result, Proj: FnMut(RelValue<'a>, RelValue<'a>) -> RelValue<'a>, { - fn head(&self) -> &Header { + fn head(&self) -> &Arc
{ &self.head } diff --git a/crates/vm/src/relation.rs b/crates/vm/src/relation.rs index ed2acfa703e..20240d56ee9 100644 --- a/crates/vm/src/relation.rs +++ b/crates/vm/src/relation.rs @@ -9,6 +9,7 @@ use spacetimedb_table::table::RowRef; use std::borrow::Cow; use std::hash::{Hash, Hasher}; use std::mem; +use std::sync::Arc; /// RelValue represents either a reference to a row in a table, /// or an ephemeral row constructed during query execution. @@ -116,15 +117,16 @@ pub struct MemTableWithoutTableName<'a> { } /// An in-memory table +// TODO(perf): Remove `Clone` impl. #[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord)] pub struct MemTable { - pub head: Header, + pub head: Arc
, pub data: Vec, pub table_access: StAccess, } impl MemTable { - pub fn new(head: Header, table_access: StAccess, data: Vec) -> Self { + pub fn new(head: Arc
, table_access: StAccess, data: Vec) -> Self { assert_eq!( head.fields.len(), data.first() @@ -141,10 +143,10 @@ impl MemTable { pub fn from_value(of: AlgebraicValue) -> Self { let head = Header::for_mem_table(of.type_of().into()); - Self::new(head, StAccess::Public, [of.into()].into()) + Self::new(Arc::new(head), StAccess::Public, [of.into()].into()) } - pub fn from_iter(head: Header, data: impl Iterator) -> Self { + pub fn from_iter(head: Arc
, data: impl Iterator) -> Self { Self { head, data: data.collect(), @@ -169,7 +171,7 @@ impl MemTable { } impl Relation for MemTable { - fn head(&self) -> &Header { + fn head(&self) -> &Arc
{ &self.head } @@ -234,7 +236,7 @@ impl Table { } impl Relation for Table { - fn head(&self) -> &Header { + fn head(&self) -> &Arc
{ match self { Table::MemTable(x) => x.head(), Table::DbTable(x) => x.head(),