diff --git a/internal/engine/cost/datasource/datasource.go b/internal/engine/cost/datasource/datasource.go index 0f309a5b7..23f5a9753 100644 --- a/internal/engine/cost/datasource/datasource.go +++ b/internal/engine/cost/datasource/datasource.go @@ -7,6 +7,8 @@ import ( "os" "strconv" "strings" + + "github.com/kwilteam/kwil-db/internal/engine/cost/datatypes" ) // ColumnValue @@ -57,7 +59,7 @@ func newRowPipeline(rows []Row) RowPipeline { } type Result struct { - schema *Schema + schema *datatypes.Schema stream RowPipeline } @@ -88,16 +90,16 @@ func (r *Result) ToCsv() string { return sb.String() } -func ResultFromStream(s *Schema, rows RowPipeline) *Result { +func ResultFromStream(s *datatypes.Schema, rows RowPipeline) *Result { return &Result{schema: s, stream: rows} } -func ResultFromRaw(s *Schema, rows []Row) *Result { +func ResultFromRaw(s *datatypes.Schema, rows []Row) *Result { // TODO: use RowPipeline all the way return &Result{schema: s, stream: newRowPipeline(rows)} } -func (r *Result) Schema() *Schema { +func (r *Result) Schema() *datatypes.Schema { return r.schema } @@ -111,7 +113,7 @@ type SourceType string // DataSource represents a data source. type DataSource interface { // Schema returns the schema for the underlying data source - Schema() *Schema + Schema() *datatypes.Schema // SourceType returns the type of the data source. SourceType() SourceType @@ -129,7 +131,7 @@ type DataSource interface { // dsScan read the data source, return selected columns. // TODO: use channel to return the result, e.g. iterator model. -func dsScan(dsSchema *Schema, dsRecords []Row, projection []string) *Result { +func dsScan(dsSchema *datatypes.Schema, dsRecords []Row, projection []string) *Result { if len(projection) == 0 { return ResultFromRaw(dsSchema, dsRecords) } @@ -148,14 +150,14 @@ func dsScan(dsSchema *Schema, dsRecords []Row, projection []string) *Result { // } //} - fieldIndex := dsSchema.mapProjection(projection) + fieldIndex := dsSchema.MapProjection(projection) - newFields := make([]Field, len(projection)) + newFields := make([]datatypes.Field, len(projection)) for i, idx := range fieldIndex { newFields[i] = dsSchema.Fields[idx] } - newSchema := NewSchema(newFields...) + newSchema := datatypes.NewSchema(newFields...) out := make(RowPipeline) go func() { @@ -175,15 +177,15 @@ func dsScan(dsSchema *Schema, dsRecords []Row, projection []string) *Result { // memDataSource is a data source that reads data from memory. type memDataSource struct { - schema *Schema + schema *datatypes.Schema records []Row } -func NewMemDataSource(s *Schema, data []Row) *memDataSource { +func NewMemDataSource(s *datatypes.Schema, data []Row) *memDataSource { return &memDataSource{schema: s, records: data} } -func (ds *memDataSource) Schema() *Schema { +func (ds *memDataSource) Schema() *datatypes.Schema { return ds.schema } @@ -199,11 +201,11 @@ func (ds *memDataSource) SourceType() SourceType { type csvDataSource struct { path string records []Row - schema *Schema + schema *datatypes.Schema } func NewCSVDataSource(path string) (*csvDataSource, error) { - ds := &csvDataSource{path: path, schema: &Schema{}} + ds := &csvDataSource{path: path, schema: &datatypes.Schema{}} if err := ds.load(); err != nil { return nil, err } @@ -270,13 +272,13 @@ func (ds *csvDataSource) load() error { for i, name := range header { ds.schema.Fields = append(ds.schema.Fields, - Field{Name: name, Type: columnTypes[i]}) + datatypes.Field{Name: name, Type: columnTypes[i]}) } return nil } -func (ds *csvDataSource) Schema() *Schema { +func (ds *csvDataSource) Schema() *datatypes.Schema { return ds.schema } diff --git a/internal/engine/cost/datasource/datasource_test.go b/internal/engine/cost/datasource/datasource_test.go index f348136dc..b14bcafeb 100644 --- a/internal/engine/cost/datasource/datasource_test.go +++ b/internal/engine/cost/datasource/datasource_test.go @@ -2,30 +2,31 @@ package datasource import ( "fmt" + "github.com/kwilteam/kwil-db/internal/engine/cost/datatypes" "testing" "github.com/stretchr/testify/assert" ) // testSchemaUsers is the same as first line of ../../testdata/users.csv -var testSchemaUsers = NewSchema( - Field{ +var testSchemaUsers = datatypes.NewSchema( + datatypes.Field{ Name: "id", Type: "int", }, - Field{ + datatypes.Field{ Name: "username", Type: "string", }, - Field{ + datatypes.Field{ Name: "age", Type: "int", }, - Field{ + datatypes.Field{ Name: "state", Type: "string", }, - Field{ + datatypes.Field{ Name: "wallet", Type: "string", }, @@ -70,7 +71,7 @@ var testDataUsers = []Row{ }, } -func checkRecords(t *testing.T, result *Result, expectedSchema *Schema, expectedData []Row) { +func checkRecords(t *testing.T, result *Result, expectedSchema *datatypes.Schema, expectedData []Row) { t.Helper() s := result.Schema() @@ -113,12 +114,12 @@ func TestMemDataSource_scanWithProjection(t *testing.T) { ds := NewMemDataSource(testSchemaUsers, testDataUsers) // Test filtered result - expectedSchema := NewSchema( - Field{ + expectedSchema := datatypes.NewSchema( + datatypes.Field{ Name: "username", Type: "string", }, - Field{ + datatypes.Field{ Name: "age", Type: "int", }) @@ -191,16 +192,16 @@ func TestCSVDataSource_scanWithProjection(t *testing.T) { assert.NoError(t, err) // Test filtered result - expectedSchema := NewSchema( - Field{ + expectedSchema := datatypes.NewSchema( + datatypes.Field{ Name: "id", Type: "int", }, - Field{ + datatypes.Field{ Name: "username", Type: "string", }, - Field{ + datatypes.Field{ Name: "state", Type: "string", }, diff --git a/internal/engine/cost/datasource/schema.go b/internal/engine/cost/datasource/schema.go deleted file mode 100644 index 21fcfe73c..000000000 --- a/internal/engine/cost/datasource/schema.go +++ /dev/null @@ -1,61 +0,0 @@ -package datasource - -import ( - "fmt" - "strings" -) - -// Field represents a field in a schema. -type Field struct { - Name string - Type string -} - -type Schema struct { - Fields []Field -} - -func (s *Schema) String() string { - var fields []string - for _, f := range s.Fields { - fields = append(fields, fmt.Sprintf("%s/%s", f.Name, f.Type)) - } - return fmt.Sprintf("[%s]", strings.Join(fields, ", ")) -} - -func (s *Schema) Select(projection ...string) *Schema { - fieldIndex := s.mapProjection(projection) - - newFields := make([]Field, len(projection)) - for i, idx := range fieldIndex { - newFields[i] = s.Fields[idx] - } - - return NewSchema(newFields...) -} - -// mapProjection maps the projection to the index of the fields in the schema. -func (s *Schema) mapProjection(projection []string) []int { - fieldIndexMap := make(map[string]int) - for i, field := range s.Fields { - fieldIndexMap[field.Name] = i - } - - newFieldsIndex := make([]int, len(projection)) - for i, name := range projection { - newFieldsIndex[i] = fieldIndexMap[name] - } - - return newFieldsIndex -} - -func (s *Schema) Join(other *Schema) *Schema { - fields := make([]Field, len(s.Fields)+len(other.Fields)) - copy(fields, s.Fields) - copy(fields[len(s.Fields):], other.Fields) - return NewSchema(fields...) -} - -func NewSchema(fields ...Field) *Schema { - return &Schema{Fields: fields} -} diff --git a/internal/engine/cost/datatypes/column.go b/internal/engine/cost/datatypes/column.go new file mode 100644 index 000000000..efd710209 --- /dev/null +++ b/internal/engine/cost/datatypes/column.go @@ -0,0 +1,14 @@ +package datatypes + +type ColumnDef struct { + Relation *TableRef + Name string +} + +func ColumnUnqualified(name string) *ColumnDef { + return &ColumnDef{Name: name} +} + +func Column(table *TableRef, name string) *ColumnDef { + return &ColumnDef{Relation: table, Name: name} +} diff --git a/internal/engine/cost/datatypes/schema.go b/internal/engine/cost/datatypes/schema.go new file mode 100644 index 000000000..cb66d76b7 --- /dev/null +++ b/internal/engine/cost/datatypes/schema.go @@ -0,0 +1,239 @@ +package datatypes + +import ( + "fmt" + "slices" + "strings" +) + + +type TableRef struct { + //DB string + Schema string + Table string +} + +func TableRefFromTable(table string) *TableRef { + return &TableRef{Table: table} +} + +func TableRefFromSchemaAndTable(schema, table string) *TableRef { + return &TableRef{Schema: schema, Table: table} +} + +// Match checks if the given table reference matches the current table reference. +// Not set fields are ignored, meaning it's optimistic to assume equal. +func (t *TableRef) Match(other *TableRef) bool { + if t.Schema != "" { + return t.Schema == other.Schema && t.Table == other.Table + } else { + return t.Table == other.Table + } +} + +// OfRelation is an interface that represents an object that is part of a relation. +type OfRelation interface { + Relation() *TableRef +} + +//type ofRelationBase struct { +// Relation *TableRef +//} +// +//func (o *ofRelationBase) Relation() *TableRef { +// return o.Relation +//} + +// Field represents a field in a schema. +type Field struct { + // ofRelationBase is used to implement the OfRelation interface. + //ofRelationBase + relation *TableRef + + Name string + Type string +} + +func NewField(name, typ string) Field { + return Field{Name: name, Type: typ} +} + +func NewFieldWithRelation(name, typ string, relation *TableRef) Field { + return Field{Name: name, Type: typ, relation: relation} +} + +func (f *Field) Relation() *TableRef { + return f.relation +} + +func (f *Field) QualifiedColumn() *ColumnDef { + return Column(f.relation, f.Name) +} + +type Schema struct { + Fields []Field +} + +func NewSchema(fields ...Field) *Schema { + return &Schema{Fields: fields} +} + +func (s *Schema) String() string { + var fields []string + for _, f := range s.Fields { + fields = append(fields, fmt.Sprintf("%s/%s", f.Name, f.Type)) + } + return fmt.Sprintf("[%s]", strings.Join(fields, ", ")) +} + +func (s *Schema) Select(projection ...string) *Schema { + fieldIndex := s.MapProjection(projection) + + newFields := make([]Field, len(projection)) + for i, idx := range fieldIndex { + newFields[i] = s.Fields[idx] + } + + return NewSchema(newFields...) +} + +// MapProjection maps the projection to the index of the fields in the schema. +// NOTE: originally it's not exported, should come back to this later. +func (s *Schema) MapProjection(projection []string) []int { + fieldIndexMap := make(map[string]int) + for i, field := range s.Fields { + fieldIndexMap[field.Name] = i + } + + newFieldsIndex := make([]int, len(projection)) + for i, name := range projection { + newFieldsIndex[i] = fieldIndexMap[name] + } + + return newFieldsIndex +} + +// Join creates a new schema by joining the fields of the current schema with +// the fields of another schema. +// NOTE: should do this on clone of the schema. +func (s *Schema) Join(other *Schema) *Schema { + fields := make([]Field, len(s.Fields)+len(other.Fields)) + copy(fields, s.Fields) + copy(fields[len(s.Fields):], other.Fields) + return NewSchema(fields...) +} + +func (s *Schema) indexOfField(relation *TableRef, name string) int { + for i, f := range s.Fields { + if relation != nil { // the field to look for is qualified + if f.Relation() != nil { // current field is qualified + if f.Relation().Match(relation) && f.Name == name { + return i + } + } + //else { // current field is unqualified + // + //} + } else { // the field to look for is unqualified + if f.Name == name { + return i + } + } + } + return -1 +} + +func (s *Schema) fieldByQualifiedName(relation *TableRef, name string) *Field { + idx := s.indexOfField(relation, name) + if idx == -1 { + panic(fmt.Sprintf("field %s.%s not found", relation.Table, name)) + //return nil + } + return &s.Fields[idx] +} + +func (s *Schema) fieldByUnqualifiedName(name string) *Field { + var found []*Field + for _, f := range s.Fields { + if f.Name == name { + found = append(found, &f) + } + } + + switch len(found) { + case 0: + panic(fmt.Sprintf("field %s not found", name)) + case 1: + return found[0] + default: + // the field without relation is the one we want + for _, f := range found { + if f.Relation() == nil { + return f + } + } + panic(fmt.Sprintf("ambiguous field %s", name)) + } +} + +func (s *Schema) FieldFromColumn(column *ColumnDef) *Field { + if column.Relation == nil { + return s.fieldByUnqualifiedName(column.Name) + } + return s.fieldByQualifiedName(column.Relation, column.Name) +} + +// Merge modifies the current schema by merging it with another schema, any +// duplicate fields will be ignored. +// NOTE: should do this on clone of the schema. +func (s *Schema) Merge(other *Schema) *Schema { + for _, f := range other.Fields { + //duplicated := false + //if f.Relation() != nil { + // duplicated = s.ContainsQualifiedColumn(f.Relation(), f.Name) + //} else { + // duplicated = s.ContainsUnqualifiedColumn(f.Name) + //} + + duplicated := s.ContainsColumn(f.Relation(), f.Name) + if !duplicated { + s.Fields = append(s.Fields, f) + } + } + + return s +} + +func (s *Schema) ContainsUnqualifiedColumn(name string) bool { + return slices.ContainsFunc(s.Fields, func(field Field) bool { + return field.Name == name + }) +} + +func (s *Schema) ContainsQualifiedColumn(relation *TableRef, name string) bool { + return slices.ContainsFunc(s.Fields, func(field Field) bool { + return field.Relation() == relation && field.Name == name + }) +} + +// ContainsColumn checks if the schema contains the given column. +// It dispatches to ContainsQualifiedColumn or ContainsUnqualifiedColumn based +// on if the relation of the column is set. +func (s *Schema) ContainsColumn(relation *TableRef, name string) bool { + if relation == nil { + return s.ContainsUnqualifiedColumn(name) + } + return s.ContainsQualifiedColumn(relation, name) +} + +// +//func (s *Schema) FieldFromColumn(column *ColumnDef) *Field { +// if column.Relation == nil { +// return s.FieldByName(column.Name) +// } +// return s.FieldByRelationAndName(column.Relation, column.Name) +//} + +func (s *Schema) Clone() *Schema { + return NewSchema(slices.Clone(s.Fields)...) //shallow clone +} diff --git a/internal/engine/cost/logical_plan/accept.go b/internal/engine/cost/logical_plan/accept.go new file mode 100644 index 000000000..18181b6ef --- /dev/null +++ b/internal/engine/cost/logical_plan/accept.go @@ -0,0 +1 @@ +package logical_plan diff --git a/internal/engine/cost/logical_plan/builder.go b/internal/engine/cost/logical_plan/builder.go index 88422b244..3737aa1eb 100644 --- a/internal/engine/cost/logical_plan/builder.go +++ b/internal/engine/cost/logical_plan/builder.go @@ -15,7 +15,7 @@ func newLogicalPlanBuilder() *logicalPlanBuilder { // NoRelation creates a new logicalPlanBuilder with no relation(from). func (b *logicalPlanBuilder) NoRelation() *logicalPlanBuilder { - return &logicalPlanBuilder{} + return &logicalPlanBuilder{plan: NoSource()} } // From creates a new logicalPlanBuilder with a logical plan. @@ -27,11 +27,13 @@ func (b *logicalPlanBuilder) JoinOn(_type string, right LogicalPlan, on LogicalE return b } +// Select applies a projection to the logical plan. func (b *logicalPlanBuilder) Select(exprs ...LogicalExpr) *logicalPlanBuilder { b.plan = Projection(b.plan, exprs...) return b } +// Limit applies LIMIT clause to the logical plan. func (b *logicalPlanBuilder) Limit(offset, limit int) *logicalPlanBuilder { b.plan = Limit(b.plan, offset, limit) return b @@ -65,6 +67,13 @@ func (b *logicalPlanBuilder) Except(right LogicalPlan) *logicalPlanBuilder { return b } +func (b *logicalPlanBuilder) Aggregate(keys []LogicalExpr, aggregates []LogicalExpr) *logicalPlanBuilder { + keys = NormalizeExprs(keys, b.plan) + aggregates = NormalizeExprs(aggregates, b.plan) + b.plan = Aggregate(b.plan, keys, aggregates) + return b +} + func (b *logicalPlanBuilder) Build() LogicalPlan { return b.plan } diff --git a/internal/engine/cost/logical_plan/expr.go b/internal/engine/cost/logical_plan/expr.go index f52330983..3109816a5 100644 --- a/internal/engine/cost/logical_plan/expr.go +++ b/internal/engine/cost/logical_plan/expr.go @@ -4,17 +4,17 @@ import ( "fmt" "strings" - "github.com/kwilteam/kwil-db/internal/engine/cost/datasource" + dt "github.com/kwilteam/kwil-db/internal/engine/cost/datatypes" + pt "github.com/kwilteam/kwil-db/internal/engine/cost/plantree" ) // LogicalExpr represents the strategies to access the required data. // It's different from tree.Expression in that it will be used to access the data. type LogicalExpr interface { - fmt.Stringer + pt.ExprNode - // Resolve returns the field that this expression represents from the input - // logical plan. - Resolve(LogicalPlan) datasource.Field + // Resolve returns the field that this expression represents from the schema + Resolve(*dt.Schema) dt.Field } type LogicalExprList []LogicalExpr @@ -30,91 +30,150 @@ func (e LogicalExprList) String() string { // ColumnExpr represents a column in a schema. // NOTE: it will be transformed to columnIdxExpr in the logical plan.???? type ColumnExpr struct { - Table string - Name string + *pt.BaseTreeNode + + Relation *dt.TableRef + Name string +} + +var _ LogicalExpr = &ColumnExpr{} + +func (e *ColumnExpr) String() string { + return e.Name } -func (c *ColumnExpr) String() string { - return c.Name +func (e *ColumnExpr) Resolve(schema *dt.Schema) dt.Field { + // TODO: use just one Column definition, right now we have: + // - ColumnExpr + // - dt.ColumnDef, to avoid circular import + return *schema.FieldFromColumn(dt.Column(e.Relation, e.Name)) } -func (c *ColumnExpr) Resolve(plan LogicalPlan) datasource.Field { - for _, field := range plan.Schema().Fields { - if field.Name == c.Name { - return field +// QualifyWithSchemas returns a new ColumnExpr with the relation set, i.e. qualified. +// NOTE: +// This feels like `Resolve`, but more coupled with implementation details. +// TODO: use all input's schemas as backup schemas +func (e *ColumnExpr) QualifyWithSchemas(schemas ...*dt.Schema) *ColumnExpr { + if e.Relation != nil { + return e + } + + var schemaToUse *dt.Schema + for _, schema := range schemas { + var matchedFields []dt.Field + for _, field := range schema.Fields { + if field.Name == e.Name { + matchedFields = append(matchedFields, field) + } + } + + switch len(matchedFields) { + case 0: + continue + case 1: + schemaToUse = schema + break + default: + // handle ambiguous column, e.g. same column name in different tables + // This can only happen when Join with USING clause, kwil doesn't support it yet. + panic(fmt.Sprintf("cannot qualify ambiguous column: %s", e.Name)) } } - panic(fmt.Sprintf("field %s not found", c.Name)) + + if schemaToUse == nil { + panic(fmt.Sprintf("field %s not found", e.Name)) + } + + field := e.Resolve(schemaToUse) + + return &ColumnExpr{ + BaseTreeNode: pt.NewBaseTreeNode(), + Relation: field.Relation(), + Name: field.Name, + } +} + +func ColumnUnqualified(name string) *ColumnExpr { + return &ColumnExpr{BaseTreeNode: pt.NewBaseTreeNode(), Name: name} } -func Column(table, name string) LogicalExpr { - return &ColumnExpr{Table: table, Name: name} +func Column(table *dt.TableRef, name string) *ColumnExpr { + return &ColumnExpr{BaseTreeNode: pt.NewBaseTreeNode(), Relation: table, Name: name} } // ColumnIdxExpr represents a column in a schema by its index. type ColumnIdxExpr struct { + *pt.BaseTreeNode + Idx int } -func (c *ColumnIdxExpr) String() string { - return fmt.Sprintf("$%d", c.Idx) +func (e *ColumnIdxExpr) String() string { + return fmt.Sprintf("$%d", e.Idx) } -func (c *ColumnIdxExpr) Resolve(plan LogicalPlan) datasource.Field { - return plan.Schema().Fields[c.Idx] +func (e *ColumnIdxExpr) Resolve(schema *dt.Schema) dt.Field { + return schema.Fields[e.Idx] } func ColumnIdx(idx int) LogicalExpr { - return &ColumnIdxExpr{Idx: idx} + return &ColumnIdxExpr{BaseTreeNode: pt.NewBaseTreeNode(), Idx: idx} } type AliasExpr struct { + *pt.BaseTreeNode + + // RELATION Expr LogicalExpr Alias string } -func (a *AliasExpr) String() string { - return fmt.Sprintf("%s AS %s", a.Expr, a.Alias) +func (e *AliasExpr) String() string { + return fmt.Sprintf("%s AS %s", e.Expr, e.Alias) } -func (a *AliasExpr) Resolve(plan LogicalPlan) datasource.Field { - return datasource.Field{Name: a.Alias, Type: a.Expr.Resolve(plan).Type} +func (e *AliasExpr) Resolve(schema *dt.Schema) dt.Field { + return dt.Field{Name: e.Alias, Type: e.Expr.Resolve(schema).Type} } -func Alias(expr LogicalExpr, alias string) LogicalExpr { - return &AliasExpr{Expr: expr, Alias: alias} +func Alias(expr LogicalExpr, alias string) *AliasExpr { + return &AliasExpr{BaseTreeNode: pt.NewBaseTreeNode(), Expr: expr, Alias: alias} } type LiteralStringExpr struct { + *pt.BaseTreeNode + Value string } -func (l *LiteralStringExpr) String() string { - return l.Value +func (e *LiteralStringExpr) String() string { + return e.Value } -func (l *LiteralStringExpr) Resolve(LogicalPlan) datasource.Field { - return datasource.Field{Name: l.Value, Type: "text"} +func (e *LiteralStringExpr) Resolve(*dt.Schema) dt.Field { + return dt.Field{Name: e.Value, Type: "text"} } -func LiteralString(value string) LogicalExpr { - return &LiteralStringExpr{Value: value} +func LiteralString(value string) *LiteralStringExpr { + return &LiteralStringExpr{BaseTreeNode: pt.NewBaseTreeNode(), Value: value} } type LiteralIntExpr struct { + *pt.BaseTreeNode + Value int } -func (l *LiteralIntExpr) String() string { - return fmt.Sprintf("%d", l.Value) +func (e *LiteralIntExpr) String() string { + return fmt.Sprintf("%d", e.Value) } -func (l *LiteralIntExpr) Resolve(LogicalPlan) datasource.Field { - return datasource.Field{Name: fmt.Sprintf("%d", l.Value), Type: "int"} +func (e *LiteralIntExpr) Resolve(*dt.Schema) dt.Field { + return dt.Field{Name: fmt.Sprintf("%d", e.Value), Type: "int"} } -func LiteralInt(value int) LogicalExpr { - return &LiteralIntExpr{Value: value} +func LiteralInt(value int) *LiteralIntExpr { + return &LiteralIntExpr{BaseTreeNode: pt.NewBaseTreeNode(), Value: value} } type OpExpr interface { @@ -130,28 +189,30 @@ type UnaryExpr interface { } type unaryExpr struct { + *pt.BaseTreeNode + name string op string expr LogicalExpr } -func (n *unaryExpr) String() string { - return fmt.Sprintf("%s %s", n.op, n.expr) +func (e *unaryExpr) String() string { + return fmt.Sprintf("%s %s", e.op, e.expr) } -func (n *unaryExpr) Op() string { - return n.op +func (e *unaryExpr) Op() string { + return e.op } -func (n *unaryExpr) Resolve(LogicalPlan) datasource.Field { - return datasource.Field{Name: n.name, Type: "bool"} +func (e *unaryExpr) Resolve(*dt.Schema) dt.Field { + return dt.Field{Name: e.name, Type: "bool"} } -func (n *unaryExpr) E() LogicalExpr { - return n.expr +func (e *unaryExpr) E() LogicalExpr { + return e.expr } -func Not(expr LogicalExpr) UnaryExpr { +func Not(expr LogicalExpr) *unaryExpr { return &unaryExpr{ name: "not", op: "NOT", @@ -174,6 +235,8 @@ type BoolBinaryExpr interface { // boolBinaryExpr represents a binary expression that returns a boolean value. type boolBinaryExpr struct { + *pt.BaseTreeNode + name string op string l LogicalExpr @@ -196,22 +259,23 @@ func (e *boolBinaryExpr) R() LogicalExpr { return e.r } -func (e *boolBinaryExpr) Resolve(LogicalPlan) datasource.Field { - return datasource.Field{Name: e.name, Type: "bool"} +func (e *boolBinaryExpr) Resolve(*dt.Schema) dt.Field { + return dt.Field{Name: e.name, Type: "bool"} } func (e *boolBinaryExpr) returnBool() {} -func And(l, r LogicalExpr) BinaryExpr { +func And(l, r LogicalExpr) *boolBinaryExpr { return &boolBinaryExpr{ - name: "and", - op: "AND", - l: l, - r: r, + BaseTreeNode: pt.NewBaseTreeNode(), + name: "and", + op: "AND", + l: l, + r: r, } } -func Or(l, r LogicalExpr) BinaryExpr { +func Or(l, r LogicalExpr) *boolBinaryExpr { return &boolBinaryExpr{ name: "or", op: "OR", @@ -220,57 +284,63 @@ func Or(l, r LogicalExpr) BinaryExpr { } } -func Eq(l, r LogicalExpr) BinaryExpr { +func Eq(l, r LogicalExpr) *boolBinaryExpr { return &boolBinaryExpr{ - name: "eq", - op: "=", - l: l, - r: r, + BaseTreeNode: pt.NewBaseTreeNode(), + name: "eq", + op: "=", + l: l, + r: r, } } -func Neq(l, r LogicalExpr) BinaryExpr { +func Neq(l, r LogicalExpr) *boolBinaryExpr { return &boolBinaryExpr{ - name: "neq", - op: "!=", - l: l, - r: r, + BaseTreeNode: pt.NewBaseTreeNode(), + name: "neq", + op: "!=", + l: l, + r: r, } } -func Gt(l, r LogicalExpr) BinaryExpr { +func Gt(l, r LogicalExpr) *boolBinaryExpr { return &boolBinaryExpr{ - name: "gt", - op: ">", - l: l, - r: r, + BaseTreeNode: pt.NewBaseTreeNode(), + name: "gt", + op: ">", + l: l, + r: r, } } -func Gte(l, r LogicalExpr) BinaryExpr { +func Gte(l, r LogicalExpr) *boolBinaryExpr { return &boolBinaryExpr{ - name: "gte", - op: ">=", - l: l, - r: r, + BaseTreeNode: pt.NewBaseTreeNode(), + name: "gte", + op: ">=", + l: l, + r: r, } } -func Lt(l, r LogicalExpr) BinaryExpr { +func Lt(l, r LogicalExpr) *boolBinaryExpr { return &boolBinaryExpr{ - name: "lt", - op: "<", - l: l, - r: r, + BaseTreeNode: pt.NewBaseTreeNode(), + name: "lt", + op: "<", + l: l, + r: r, } } -func Lte(l, r LogicalExpr) BinaryExpr { +func Lte(l, r LogicalExpr) *boolBinaryExpr { return &boolBinaryExpr{ - name: "lte", - op: "<=", - l: l, - r: r, + BaseTreeNode: pt.NewBaseTreeNode(), + name: "lte", + op: "<=", + l: l, + r: r, } } @@ -283,6 +353,8 @@ type ArithmeticBinaryExpr interface { // arithmeticBinaryExpr represents a binary expression that performs arithmetic // operations, which return type of one of the operands. type arithmeticBinaryExpr struct { + *pt.BaseTreeNode + name string op string l LogicalExpr @@ -305,45 +377,49 @@ func (e *arithmeticBinaryExpr) R() LogicalExpr { return e.r } -func (e *arithmeticBinaryExpr) Resolve(plan LogicalPlan) datasource.Field { - return datasource.Field{Name: e.name, Type: e.l.Resolve(plan).Type} +func (e *arithmeticBinaryExpr) Resolve(schema *dt.Schema) dt.Field { + return dt.Field{Name: e.name, Type: e.l.Resolve(schema).Type} } func (e *arithmeticBinaryExpr) returnOperandType() {} -func Add(l, r LogicalExpr) BinaryExpr { +func Add(l, r LogicalExpr) *arithmeticBinaryExpr { return &arithmeticBinaryExpr{ - name: "add", - op: "+", - l: l, - r: r, + BaseTreeNode: pt.NewBaseTreeNode(), + name: "add", + op: "+", + l: l, + r: r, } } -func Sub(l, r LogicalExpr) BinaryExpr { +func Sub(l, r LogicalExpr) *arithmeticBinaryExpr { return &arithmeticBinaryExpr{ - name: "sub", - op: "-", - l: l, - r: r, + BaseTreeNode: pt.NewBaseTreeNode(), + name: "sub", + op: "-", + l: l, + r: r, } } -func Mul(l, r LogicalExpr) BinaryExpr { +func Mul(l, r LogicalExpr) *arithmeticBinaryExpr { return &arithmeticBinaryExpr{ - name: "mul", - op: "*", - l: l, - r: r, + BaseTreeNode: pt.NewBaseTreeNode(), + name: "mul", + op: "*", + l: l, + r: r, } } -func Div(l, r LogicalExpr) BinaryExpr { +func Div(l, r LogicalExpr) *arithmeticBinaryExpr { return &arithmeticBinaryExpr{ - name: "div", - op: "/", - l: l, - r: r, + BaseTreeNode: pt.NewBaseTreeNode(), + name: "div", + op: "/", + l: l, + r: r, } } @@ -357,43 +433,47 @@ type AggregateExpr interface { // aggregateExpr represents an aggregate expression. // It returns a single value for a group of rows. type aggregateExpr struct { + *pt.BaseTreeNode + name string expr LogicalExpr //NOTE add alias?? } -func (a *aggregateExpr) String() string { - return fmt.Sprintf("%s(%s)", a.name, a.expr) +func (e *aggregateExpr) String() string { + return fmt.Sprintf("%s(%s)", e.name, e.expr) } -func (a *aggregateExpr) Resolve(plan LogicalPlan) datasource.Field { - return datasource.Field{Name: a.name, Type: a.expr.Resolve(plan).Type} +func (e *aggregateExpr) Resolve(schema *dt.Schema) dt.Field { + return dt.Field{Name: e.name, Type: e.expr.Resolve(schema).Type} } -func (a *aggregateExpr) E() LogicalExpr { - return a.expr +func (e *aggregateExpr) E() LogicalExpr { + return e.expr } -func (a *aggregateExpr) aggregate() {} +func (e *aggregateExpr) aggregate() {} -func Max(expr LogicalExpr) AggregateExpr { - return &aggregateExpr{name: "MAX", expr: expr} +func Max(expr LogicalExpr) *aggregateExpr { + return &aggregateExpr{BaseTreeNode: pt.NewBaseTreeNode(), name: "MAX", expr: expr} } -func Min(expr LogicalExpr) AggregateExpr { - return &aggregateExpr{name: "MIN", expr: expr} +func Min(expr LogicalExpr) *aggregateExpr { + return &aggregateExpr{BaseTreeNode: pt.NewBaseTreeNode(), name: "MIN", expr: expr} } -func Avg(expr LogicalExpr) AggregateExpr { - return &aggregateExpr{name: "AVG", expr: expr} +func Avg(expr LogicalExpr) *aggregateExpr { + return &aggregateExpr{BaseTreeNode: pt.NewBaseTreeNode(), name: "AVG", expr: expr} } -func Sum(expr LogicalExpr) AggregateExpr { - return &aggregateExpr{name: "SUM", expr: expr} +func Sum(expr LogicalExpr) *aggregateExpr { + return &aggregateExpr{BaseTreeNode: pt.NewBaseTreeNode(), name: "SUM", expr: expr} } // aggregateIntExpr represents an aggregate expression that returns an integer. type aggregateIntExpr struct { + *pt.BaseTreeNode + name string expr LogicalExpr } @@ -402,8 +482,8 @@ func (a *aggregateIntExpr) String() string { return fmt.Sprintf("%s(%s)", a.name, a.expr) } -func (a *aggregateIntExpr) Resolve(LogicalPlan) datasource.Field { - return datasource.Field{Name: a.name, Type: "int"} +func (a *aggregateIntExpr) Resolve(*dt.Schema) dt.Field { + return dt.Field{Name: a.name, Type: "int"} } func (a *aggregateIntExpr) E() LogicalExpr { @@ -412,8 +492,8 @@ func (a *aggregateIntExpr) E() LogicalExpr { func (a *aggregateIntExpr) aggregate() {} -func Count(expr LogicalExpr) AggregateExpr { - return &aggregateIntExpr{name: "COUNT", expr: expr} +func Count(expr LogicalExpr) *aggregateIntExpr { + return &aggregateIntExpr{BaseTreeNode: pt.NewBaseTreeNode(), name: "COUNT", expr: expr} } type binaryExprBuilder interface { @@ -485,23 +565,213 @@ func (b *binaryExprBuilderImpl) Div(r LogicalExpr) BinaryExpr { } type sortExpr struct { + *pt.BaseTreeNode + expr LogicalExpr asc bool nullsFirst bool } -func (s *sortExpr) String() string { - return fmt.Sprintf("%s %s", s.expr, s.asc) +func (e *sortExpr) String() string { + return fmt.Sprintf("%s %v", e.expr, e.asc) } -func (s *sortExpr) Resolve(plan LogicalPlan) datasource.Field { - return s.expr.Resolve(plan) +func (e *sortExpr) Resolve(schema *dt.Schema) dt.Field { + return e.expr.Resolve(schema) } func SortExpr(expr LogicalExpr, asc, nullsFirst bool) *sortExpr { return &sortExpr{ - expr: expr, - asc: asc, - nullsFirst: nullsFirst, + BaseTreeNode: pt.NewBaseTreeNode(), + expr: expr, + asc: asc, + nullsFirst: nullsFirst, + } +} + +//// pt.TreeNode implementation +// Children() implementation + +func (e *ColumnExpr) Children() []pt.TreeNode { + return []pt.TreeNode{} +} + +func (e *ColumnIdxExpr) Children() []pt.TreeNode { + return []pt.TreeNode{} +} + +func (e *AliasExpr) Children() []pt.TreeNode { + return []pt.TreeNode{e.Expr} +} + +func (e *LiteralStringExpr) Children() []pt.TreeNode { + return []pt.TreeNode{} +} + +func (e *LiteralIntExpr) Children() []pt.TreeNode { + return []pt.TreeNode{} +} + +func (e *unaryExpr) Children() []pt.TreeNode { + return []pt.TreeNode{e.expr} +} + +func (e *boolBinaryExpr) Children() []pt.TreeNode { + return []pt.TreeNode{e.l, e.r} +} + +func (e *arithmeticBinaryExpr) Children() []pt.TreeNode { + return []pt.TreeNode{e.l, e.r} +} + +func (e *aggregateExpr) Children() []pt.TreeNode { + return []pt.TreeNode{e.expr} +} + +func (e *aggregateIntExpr) Children() []pt.TreeNode { + return []pt.TreeNode{e.expr} +} + +func (e *sortExpr) Children() []pt.TreeNode { + return []pt.TreeNode{e.expr} +} + +// TransformChildren() implementation + +func (e *ColumnExpr) TransformChildren(fn pt.TransformFunc) pt.TreeNode { + return e +} + +func (e *ColumnIdxExpr) TransformChildren(fn pt.TransformFunc) pt.TreeNode { + return e +} + +func (e *AliasExpr) TransformChildren(fn pt.TransformFunc) pt.TreeNode { + return &AliasExpr{ + BaseTreeNode: pt.NewBaseTreeNode(), + Expr: fn(e.Expr).(LogicalExpr), + Alias: e.Alias, + } +} + +func (e *LiteralStringExpr) TransformChildren(fn pt.TransformFunc) pt.TreeNode { + return e +} + +func (e *LiteralIntExpr) TransformChildren(fn pt.TransformFunc) pt.TreeNode { + return e +} + +func (e *unaryExpr) TransformChildren(fn pt.TransformFunc) pt.TreeNode { + return &unaryExpr{ + BaseTreeNode: pt.NewBaseTreeNode(), + name: e.name, + op: e.op, + expr: fn(e.expr).(LogicalExpr), + } +} + +func (e *boolBinaryExpr) TransformChildren(fn pt.TransformFunc) pt.TreeNode { + return &boolBinaryExpr{ + BaseTreeNode: pt.NewBaseTreeNode(), + name: e.name, + op: e.op, + l: fn(e.l).(LogicalExpr), + r: fn(e.r).(LogicalExpr), + } +} + +func (e *arithmeticBinaryExpr) TransformChildren(fn pt.TransformFunc) pt.TreeNode { + return &arithmeticBinaryExpr{ + BaseTreeNode: pt.NewBaseTreeNode(), + name: e.name, + op: e.op, + l: fn(e.l).(LogicalExpr), + r: fn(e.r).(LogicalExpr), + } +} + +func (e *aggregateExpr) TransformChildren(fn pt.TransformFunc) pt.TreeNode { + return &aggregateExpr{ + + BaseTreeNode: pt.NewBaseTreeNode(), + name: e.name, + expr: fn(e.expr).(LogicalExpr), + } +} + +func (e *aggregateIntExpr) TransformChildren(fn pt.TransformFunc) pt.TreeNode { + return &aggregateIntExpr{ + BaseTreeNode: pt.NewBaseTreeNode(), + name: e.name, + expr: fn(e.expr).(LogicalExpr), + } +} + +func (e *sortExpr) TransformChildren(fn pt.TransformFunc) pt.TreeNode { + return &sortExpr{ + BaseTreeNode: pt.NewBaseTreeNode(), + expr: fn(e.expr).(LogicalExpr), + asc: e.asc, + nullsFirst: e.nullsFirst, } } + +// ExprNode() implementation + +func (e *ColumnExpr) ExprNode() {} +func (e *ColumnIdxExpr) ExprNode() {} +func (e *AliasExpr) ExprNode() {} +func (e *LiteralStringExpr) ExprNode() {} +func (e *LiteralIntExpr) ExprNode() {} +func (e *unaryExpr) ExprNode() {} +func (e *boolBinaryExpr) ExprNode() {} +func (e *arithmeticBinaryExpr) ExprNode() {} +func (e *aggregateExpr) ExprNode() {} +func (e *aggregateIntExpr) ExprNode() {} +func (e *sortExpr) ExprNode() {} + +///////////////////////////////// + +type OnionOrderVisitor struct{} + +func (v *OnionOrderVisitor) Visit(n pt.TreeNode) (bool, interface{}) { + return pt.OnionOrderVisit(v, n) +} + +func (v *OnionOrderVisitor) PreVisit(n pt.TreeNode) (bool, interface{}) { + panic("implement me") +} + +func (v *OnionOrderVisitor) VisitChildren(n pt.TreeNode) (bool, interface{}) { + return pt.ApplyNodeFuncToChildren(n, v.Visit) +} + +func (v *OnionOrderVisitor) PostVisit(n pt.TreeNode) (bool, interface{}) { + panic("implement me") +} + +func VisitLogicalExpr(expr LogicalExpr, visitor pt.TreeNodeVisitor) (bool, LogicalExpr) { + switch e := expr.(type) { + case pt.TreeNode: + //return visitor.Visit(e) + default: + return true, e.(LogicalExpr) + } + return true, expr +} + +//// exprNodeTransform transforms the expression node using the given function. +//// This is an alternative to the TransformUp method of the expression node. +//func exprNodeTransform(node pt.TreeNode, fn pt.TransformFunc) pt.TreeNode { +// switch e := node.(type) { +// case *ColumnExpr: +// return Column(e.Relation, e.Name) +// case *ColumnIdxExpr: +// return ColumnIdx(e.Idx) +// case *AliasExpr: +// return Alias(fn(e.Expr).(LogicalExpr), e.Alias) +// default: +// panic(fmt.Sprintf("unknown expression type %T", e)) +// } +//} diff --git a/internal/engine/cost/logical_plan/operator.go b/internal/engine/cost/logical_plan/operator.go index b40dcc84c..21f308300 100644 --- a/internal/engine/cost/logical_plan/operator.go +++ b/internal/engine/cost/logical_plan/operator.go @@ -3,8 +3,33 @@ package logical_plan import ( "fmt" "github.com/kwilteam/kwil-db/internal/engine/cost/datasource" + "github.com/kwilteam/kwil-db/internal/engine/cost/datatypes" ) +// NoFrom represents a no from operator. +// It corresponds to select without any from clause in SQL. +type NoFrom struct{} + +func (n *NoFrom) String() string { + return "NoFrom" +} + +func (n *NoFrom) Schema() *datatypes.Schema { + return datatypes.NewSchema() +} + +func (n *NoFrom) Inputs() []LogicalPlan { + return []LogicalPlan{} +} + +func (n *NoFrom) Exprs() []LogicalExpr { + return []LogicalExpr{} +} + +func NoSource() LogicalPlan { + return &NoFrom{} +} + // ScanOp represents a table scan operator, which produces rows from a table. // It corresponds to `FROM` clause in SQL. type ScanOp struct { @@ -43,7 +68,7 @@ func (s *ScanOp) String() string { return fmt.Sprintf("Scan: %s; projection=%s", s.table, s.projection) } -func (s *ScanOp) Schema() *datasource.Schema { +func (s *ScanOp) Schema() *datatypes.Schema { return s.dataSource.Schema().Select(s.projection...) } @@ -72,12 +97,12 @@ func (p *ProjectionOp) String() string { return fmt.Sprintf("Projection: %s", p.exprs) } -func (p *ProjectionOp) Schema() *datasource.Schema { - fs := make([]datasource.Field, len(p.exprs)) +func (p *ProjectionOp) Schema() *datatypes.Schema { + fs := make([]datatypes.Field, len(p.exprs)) for i, expr := range p.exprs { - fs[i] = expr.Resolve(p.input) + fs[i] = expr.Resolve(p.input.Schema()) } - return datasource.NewSchema(fs...) + return datatypes.NewSchema(fs...) } func (p *ProjectionOp) Inputs() []LogicalPlan { @@ -108,7 +133,7 @@ func (s *SelectionOp) String() string { return fmt.Sprintf("Selection: %s", s.expr) } -func (s *SelectionOp) Schema() *datasource.Schema { +func (s *SelectionOp) Schema() *datatypes.Schema { return s.input.Schema() } @@ -134,14 +159,14 @@ func Selection(plan LogicalPlan, expr LogicalExpr) LogicalPlan { type AggregateOp struct { input LogicalPlan groupBy []LogicalExpr - aggregate []AggregateExpr + aggregate []LogicalExpr } func (a *AggregateOp) GroupBy() []LogicalExpr { return a.groupBy } -func (a *AggregateOp) Aggregate() []AggregateExpr { +func (a *AggregateOp) Aggregate() []LogicalExpr { return a.aggregate } @@ -150,19 +175,19 @@ func (a *AggregateOp) String() string { } // Schema returns groupBy fields and aggregate fields -func (a *AggregateOp) Schema() *datasource.Schema { +func (a *AggregateOp) Schema() *datatypes.Schema { groupByLen := len(a.groupBy) - fs := make([]datasource.Field, len(a.aggregate)+groupByLen) + fs := make([]datatypes.Field, len(a.aggregate)+groupByLen) for i, expr := range a.groupBy { - fs[i] = expr.Resolve(a.input) + fs[i] = expr.Resolve(a.input.Schema()) } for i, expr := range a.aggregate { - fs[i+groupByLen] = expr.Resolve(a.input) + fs[i+groupByLen] = expr.Resolve(a.input.Schema()) } - return datasource.NewSchema(fs...) + return datatypes.NewSchema(fs...) } func (a *AggregateOp) Inputs() []LogicalPlan { @@ -184,7 +209,7 @@ func (a *AggregateOp) Exprs() []LogicalExpr { // Aggregate creates an aggregation logical plan. func Aggregate(plan LogicalPlan, groupBy []LogicalExpr, - aggregateExpr []AggregateExpr) LogicalPlan { + aggregateExpr []LogicalExpr) LogicalPlan { return &AggregateOp{ input: plan, groupBy: groupBy, @@ -213,7 +238,7 @@ func (l *LimitOp) String() string { return fmt.Sprintf("Limit: %d, offset %d", l.limit, l.offset) } -func (l *LimitOp) Schema() *datasource.Schema { +func (l *LimitOp) Schema() *datatypes.Schema { return l.input.Schema() } @@ -251,7 +276,7 @@ func (s *SortOp) String() string { return fmt.Sprintf("Sort: %s", s.by) } -func (s *SortOp) Schema() *datasource.Schema { +func (s *SortOp) Schema() *datatypes.Schema { return s.input.Schema() } @@ -315,13 +340,13 @@ func (j *JoinOp) OpType() JoinType { } // Schema returns the combination of left and right schema -func (j *JoinOp) Schema() *datasource.Schema { +func (j *JoinOp) Schema() *datatypes.Schema { leftFields := j.left.Schema().Fields rightFields := j.right.Schema().Fields - fields := make([]datasource.Field, len(leftFields)+len(rightFields)) + fields := make([]datatypes.Field, len(leftFields)+len(rightFields)) copy(fields, leftFields) copy(fields[len(leftFields):], rightFields) - return datasource.NewSchema(fields...) + return datatypes.NewSchema(fields...) } func (j *JoinOp) Inputs() []LogicalPlan { @@ -386,7 +411,7 @@ func (u *BagOp) String() string { } // Schema returns the schema of the left plan, since they should be the same. -func (u *BagOp) Schema() *datasource.Schema { +func (u *BagOp) Schema() *datatypes.Schema { return u.left.Schema() } @@ -440,7 +465,7 @@ func (s *SubqueryOp) String() string { return fmt.Sprintf("Subquery: %s", s.alias) } -func (s *SubqueryOp) Schema() *datasource.Schema { +func (s *SubqueryOp) Schema() *datatypes.Schema { return s.input.Schema() } @@ -468,7 +493,7 @@ func (d *DistinctOp) String() string { return "Distinct" } -func (d *DistinctOp) Schema() *datasource.Schema { +func (d *DistinctOp) Schema() *datatypes.Schema { return d.input.Schema() } diff --git a/internal/engine/cost/logical_plan/operator_test.go b/internal/engine/cost/logical_plan/operator_test.go index a588a320a..83d211c51 100644 --- a/internal/engine/cost/logical_plan/operator_test.go +++ b/internal/engine/cost/logical_plan/operator_test.go @@ -10,7 +10,8 @@ import ( func ExampleLogicalPlan_String_selection() { ds := datasource.NewMemDataSource(nil, nil) plan := logical_plan.Scan("users", ds, nil) - plan = logical_plan.Projection(plan, logical_plan.Column("", "username"), logical_plan.Column("", "age")) + //plan = logical_plan.Projection(plan, logical_plan.Column("", "username"), logical_plan.Column("", "age")) + plan = logical_plan.Projection(plan, logical_plan.ColumnUnqualified("username"), logical_plan.ColumnUnqualified("age")) fmt.Println(logical_plan.Format(plan, 0)) // Output: // Projection: username, age @@ -20,17 +21,17 @@ func ExampleLogicalPlan_String_selection() { func ExampleLogicalPlan_DataFrame() { ds := datasource.NewMemDataSource(nil, nil) aop := logical_plan.NewDataFrame(logical_plan.Scan("users", ds, nil)) - plan := aop.Filter(logical_plan.Eq(logical_plan.Column("", "age"), logical_plan.LiteralInt(20))). - Aggregate([]logical_plan.LogicalExpr{logical_plan.Column("", "state")}, - []logical_plan.AggregateExpr{logical_plan.Count(logical_plan.Column("", "username"))}). + plan := aop.Filter(logical_plan.Eq(logical_plan.ColumnUnqualified("age"), logical_plan.LiteralInt(20))). + Aggregate([]logical_plan.LogicalExpr{logical_plan.ColumnUnqualified("state")}, + []logical_plan.LogicalExpr{logical_plan.Count(logical_plan.ColumnUnqualified("username"))}). // the alias for aggregate result is bit weird - Project(logical_plan.Column("", "state"), logical_plan.Alias(logical_plan.Count(logical_plan.Column("", "username")), "num")). + Project(logical_plan.ColumnUnqualified("state"), logical_plan.Alias(logical_plan.Count(logical_plan.ColumnUnqualified("username")), "num")). LogicalPlan() fmt.Println(logical_plan.Format(plan, 0)) // Output: // Projection: state, COUNT(username) AS num // Aggregate: [state], [COUNT(username)] - // Selection: [age = 20] + // Selection: age = 20 // Scan: users; projection=[] } diff --git a/internal/engine/cost/logical_plan/plan.go b/internal/engine/cost/logical_plan/plan.go index 33ea3c570..f9cdfb6c6 100644 --- a/internal/engine/cost/logical_plan/plan.go +++ b/internal/engine/cost/logical_plan/plan.go @@ -3,15 +3,14 @@ package logical_plan import ( "bytes" "fmt" - - "github.com/kwilteam/kwil-db/internal/engine/cost/datasource" + "github.com/kwilteam/kwil-db/internal/engine/cost/datatypes" ) type LogicalPlan interface { fmt.Stringer // Schema returns the schema of the data that will be produced by this LogicalPlan. - Schema() *datasource.Schema + Schema() *datatypes.Schema Inputs() []LogicalPlan @@ -30,10 +29,10 @@ type DataFrameAPI interface { Filter(expr LogicalExpr) DataFrameAPI // Aggregate appliex an aggregation - Aggregate(groupBy []LogicalExpr, aggregateExpr []AggregateExpr) DataFrameAPI + Aggregate(groupBy []LogicalExpr, aggregateExpr []LogicalExpr) DataFrameAPI // Schema returns the schema of the data that will be produced by this DataFrameAPI. - Schema() *datasource.Schema + Schema() *datatypes.Schema // LogicalPlan returns the logical plan LogicalPlan() LogicalPlan @@ -51,11 +50,11 @@ func (df *DataFrame) Filter(expr LogicalExpr) DataFrameAPI { return &DataFrame{Selection(df.plan, expr)} } -func (df *DataFrame) Aggregate(groupBy []LogicalExpr, aggregateExpr []AggregateExpr) DataFrameAPI { +func (df *DataFrame) Aggregate(groupBy []LogicalExpr, aggregateExpr []LogicalExpr) DataFrameAPI { return &DataFrame{Aggregate(df.plan, groupBy, aggregateExpr)} } -func (df *DataFrame) Schema() *datasource.Schema { +func (df *DataFrame) Schema() *datatypes.Schema { return df.plan.Schema() } diff --git a/internal/engine/cost/logical_plan/utils.go b/internal/engine/cost/logical_plan/utils.go index 54c5a3ff6..a8e5fd959 100644 --- a/internal/engine/cost/logical_plan/utils.go +++ b/internal/engine/cost/logical_plan/utils.go @@ -2,7 +2,8 @@ package logical_plan import ( "fmt" - ds "github.com/kwilteam/kwil-db/internal/engine/cost/datasource" + ds "github.com/kwilteam/kwil-db/internal/engine/cost/datatypes" + pt "github.com/kwilteam/kwil-db/internal/engine/cost/plantree" ) // SplitConjunction splits the given expression into a list of expressions. @@ -66,3 +67,48 @@ func ExtractColumns(expr LogicalExpr, panic(fmt.Sprintf("unknown expression type %T", e)) } } + +// NormalizeColumn qualifies a column with gaven logical plan. +func NormalizeColumn(plan LogicalPlan, column *ColumnExpr) *ColumnExpr { + return column.QualifyWithSchemas(plan.Schema()) +} + +// NormalizeExpr normalizes the given expression with the given logical plan. +func NormalizeExpr(expr LogicalExpr, plan LogicalPlan) LogicalExpr { + e := expr.TransformUp(func(n pt.TreeNode) pt.TreeNode { + if c, ok := n.(*ColumnExpr); ok { + return NormalizeColumn(plan, c) + } + return n + }) + + return e.(LogicalExpr) +} + +func NormalizeExprs(exprs []LogicalExpr, plan LogicalPlan) []LogicalExpr { + normalized := make([]LogicalExpr, len(exprs)) + for i, e := range exprs { + normalized[i] = NormalizeExpr(e, plan) + } + return normalized +} + +func ResolveColumns(expr LogicalExpr, plan LogicalPlan) LogicalExpr { + return expr.TransformUp(func(n pt.TreeNode) pt.TreeNode { + if c, ok := n.(*ColumnExpr); ok { + c.QualifyWithSchemas(plan.Schema()) + } + return n + }).(LogicalExpr) +} + +func ColumnFromDefToExpr(column *ds.ColumnDef) *ColumnExpr { + return Column(column.Relation, column.Name) +} + +func ColumnFromExprToDef(column *ColumnExpr) *ds.ColumnDef { + return &ds.ColumnDef{ + Relation: column.Relation, + Name: column.Name, + } +} diff --git a/internal/engine/cost/logical_plan/utils_test.go b/internal/engine/cost/logical_plan/utils_test.go index 86815d2c7..e288c14fd 100644 --- a/internal/engine/cost/logical_plan/utils_test.go +++ b/internal/engine/cost/logical_plan/utils_test.go @@ -1,6 +1,7 @@ package logical_plan import ( + "github.com/kwilteam/kwil-db/internal/engine/cost/datatypes" "github.com/stretchr/testify/assert" "reflect" "testing" @@ -10,6 +11,9 @@ func TestSplitConjunction(t *testing.T) { type args struct { expr LogicalExpr } + + t1 := datatypes.TableRefFromTable("t1") + tests := []struct { name string args args @@ -19,74 +23,74 @@ func TestSplitConjunction(t *testing.T) { name: "1 level AND", args: args{ expr: And( - Column("t1", "a"), - Column("t1", "b"), + Column(t1, "a"), + Column(t1, "b"), ), }, want: []LogicalExpr{ - Column("t1", "a"), - Column("t1", "b"), + Column(t1, "a"), + Column(t1, "b"), }, }, { name: "2 level AND", args: args{ expr: And( - Column("t1", "a"), + Column(t1, "a"), And( - Column("t1", "b"), - Column("t1", "c"), + Column(t1, "b"), + Column(t1, "c"), ), ), }, want: []LogicalExpr{ - Column("t1", "a"), - Column("t1", "b"), - Column("t1", "c"), + Column(t1, "a"), + Column(t1, "b"), + Column(t1, "c"), }, }, { name: "with alias", args: args{ expr: And( - Alias(Column("t1", "a"), "a"), - Alias(Column("t1", "b"), "b"), + Alias(Column(t1, "a"), "a"), + Alias(Column(t1, "b"), "b"), ), }, want: []LogicalExpr{ - Column("t1", "a"), - Column("t1", "b"), + Column(t1, "a"), + Column(t1, "b"), }, }, { name: "with binary expr", args: args{ expr: And( - Column("t1", "a"), - Eq(Column("t1", "b"), LiteralInt(1)), + Column(t1, "a"), + Eq(Column(t1, "b"), LiteralInt(1)), ), }, want: []LogicalExpr{ - Column("t1", "a"), - Eq(Column("t1", "b"), LiteralInt(1)), + Column(t1, "a"), + Eq(Column(t1, "b"), LiteralInt(1)), }, }, { name: "no conjunction", args: args{ - expr: Eq(Column("t1", "a"), LiteralInt(1)), + expr: Eq(Column(t1, "a"), LiteralInt(1)), }, want: []LogicalExpr{ - Eq(Column("t1", "a"), LiteralInt(1)), + Eq(Column(t1, "a"), LiteralInt(1)), }, }, { name: "no conjunction with alias", args: args{ - expr: Alias(Eq(Column("t1", "a"), LiteralInt(1)), "a"), + expr: Alias(Eq(Column(t1, "a"), LiteralInt(1)), "a"), }, want: []LogicalExpr{ - Eq(Column("t1", "a"), LiteralInt(1)), + Eq(Column(t1, "a"), LiteralInt(1)), }, }, } @@ -103,6 +107,8 @@ func TestConjunction(t *testing.T) { type args struct { exprs []LogicalExpr } + t1 := datatypes.TableRefFromTable("t1") + tests := []struct { name string args args @@ -112,28 +118,28 @@ func TestConjunction(t *testing.T) { name: "1 level AND", args: args{ exprs: []LogicalExpr{ - Column("t1", "a"), - Column("t1", "b"), + Column(t1, "a"), + Column(t1, "b"), }, }, wantExpr: And( - Column("t1", "a"), - Column("t1", "b"), + Column(t1, "a"), + Column(t1, "b"), ), }, { name: "2 level AND", args: args{ exprs: []LogicalExpr{ - Column("t1", "a"), - Column("t1", "b"), - Column("t1", "c"), + Column(t1, "a"), + Column(t1, "b"), + Column(t1, "c"), }, }, wantExpr: And( - And(Column("t1", "a"), - Column("t1", "b")), - Column("t1", "c"), + And(Column(t1, "a"), + Column(t1, "b")), + Column(t1, "c"), ), }, } diff --git a/internal/engine/cost/plantree/tree.go b/internal/engine/cost/plantree/tree.go new file mode 100644 index 000000000..63e352cbf --- /dev/null +++ b/internal/engine/cost/plantree/tree.go @@ -0,0 +1,294 @@ +package plantree + +import "fmt" + +type NodeFunc func(TreeNode) (bool, any) +type TransformFunc func(TreeNode) TreeNode + +// Tree represents a node in a tree, which is visitable. +type Tree interface { + Children() []TreeNode + ShallowClone() TreeNode + fmt.Stringer +} + +type TreeNode interface { + Tree + + // Accept visits the node and its children using the provided TreeNodeVisitor. + Accept(TreeNodeVisitor) (keepGoing bool, res any) + + //// Apply walks through the node and its children and applies the provided NodeFunc. + //// The point is that you can quickly apply a function to the whole tree. + //// It's a light version of Accept. + //Apply(NodeFunc) (bool, any) + + // TransformUp applies the provided TransformFunc to copied node, and apply + // it to its children by calling TransformChildren, returns transformed node. + // It traverses the tree in DFS post-order. + TransformUp(TransformFunc) TreeNode + // TransformDown is like TransformUp, but it traverses the tree in DFS pre-order. + TransformDown(TransformFunc) TreeNode + // TransformChildren applies the provided TransformFunc to copied children. + TransformChildren(TransformFunc) TreeNode +} + +type ExprNode interface { + TreeNode + + ExprNode() +} + +type PlanNode interface { + TreeNode + + PlanNode() +} + +//type BaseNode struct { +// children []Tree +//} +// +//func (n *BaseNode) Children() []Tree { +// return n.children +//} +// +//// Accept visits the node and its children using the provided TreeNodeVisitor. +//// It traverses the tree in DFS pre-order. +//func (n *BaseNode) Accept(v TreeNodeVisitor) (bool, any) { +// keepGoing, res := v.PreVisit(n) +// if !keepGoing { +// return false, res +// } +// +// keepGoing, res = n.visitChildren(v) +// if !keepGoing { +// return false, res +// } +// +// return v.PostVisit(n) +//} +// +//func (n *BaseNode) visitChildren(v TreeNodeVisitor) (bool, any) { +// for _, child := range n.children { +// keepGoing, res := child.Accept(v) +// if !keepGoing { +// return false, res +// } +// } +// return true, nil +//} +// +//func (n *BaseNode) TransformUp(f TransformFunc) Tree { +// return n.transformPostOrder(f) +//} +// +//func (n *BaseNode) transformPostOrder(f TransformFunc) Tree { +// transformed := n.TransformChildren(f) +// return f(transformed) +//} +// +//func (n *BaseNode) TransformChildren(f TransformFunc) Tree { +// panic("not implemented") +//} + +////type RecursiveNext int8 +//// +////const ( +//// RecursiveNextStop RecursiveNext = iota +//// RecursiveNextContinue +//// RecursiveNextSkip +////) +// +//// TreeNodeVisitor implements the visitor pattern for walking Tree recursively. +//type TreeNodeVisitor interface { +// // PreVisit is called before visiting the children of the node. +// PreVisit(node Tree) (bool, any) +// +// // PostVisit is called after visiting the children of the node. +// PostVisit(node Tree) (bool, any) +//} +// +//type BaseNodeVisitor struct{} +// +//func (v *BaseNodeVisitor) PreVisit(node Tree) (bool, any) { +// return true, nil +//} +// +//func (v *BaseNodeVisitor) PostVisit(node Tree) (bool, any) { +// return true, nil +//} +// +//type BaseTreeNode struct{} +// +//func (n *BaseTreeNode) Children() []TreeNode { +// panic("not implemented") +//} +// +//func (n *BaseTreeNode) Accept(v TreeNodeVisitor) (bool, any) { +// keepGoing, res := v.PreVisit(n) +// if !keepGoing { +// return false, res +// } +// +// keepGoing, res = v.VisitChildren(n) +// if !keepGoing { +// return false, res +// } +// +// return v.PostVisit(n) +//} + +// OnionOrderVisit visits the tree in onion order, ((())) like. +func OnionOrderVisit(v TreeNodeVisitor, node TreeNode) (bool, any) { + keepGoing, res := v.PreVisit(node) + if !keepGoing { + return false, res + } + + keepGoing, res = v.VisitChildren(node) + if !keepGoing { + return false, res + } + + return v.PostVisit(node) +} + +func ApplyNodeFuncToChildren(node TreeNode, fn NodeFunc) (bool, any) { + for _, child := range node.Children() { + keepGoing, res := fn(child) + if !keepGoing { + return false, res + } + } + return true, nil +} + +// PreOrderApply walks through the node and its children and applies the +// provided NodeFunc, in pre-order. +// The point is that you can quickly apply a function to the whole tree. +// It's a light version of TreeNodeVisitor. +func PreOrderApply(node TreeNode, fn NodeFunc) (bool, any) { + keepGoing, res := fn(node) + if !keepGoing { + return false, res + } + + return ApplyNodeFuncToChildren(node, + func(node TreeNode) (bool, any) { + return PreOrderApply(node, fn) + }) +} + +// PostOrderApply walks through the node and its children and applies the +// provided NodeFunc, in post-order. +// The point is that you can quickly apply a function to the whole tree. +// It's a light version of TreeNodeVisitor. +// PostOrder means all children are visited before the node itself. +func PostOrderApply(node TreeNode, fn NodeFunc) (bool, any) { + keepGoing, res := ApplyNodeFuncToChildren(node, + func(node TreeNode) (bool, any) { + return PostOrderApply(node, fn) + }) + if !keepGoing { + return false, res + } + + return fn(node) +} + +// NodeTransformFunc is a function that transforms a node and its children using +// the provided TransformFunc. +type NodeTransformFunc func(node TreeNode, transformFunc TransformFunc) TreeNode + +func PostOrderTransform(node TreeNode, fn TransformFunc, nodeFn NodeTransformFunc) TreeNode { + newChildren := nodeFn(node, func(n TreeNode) TreeNode { + return PostOrderTransform(n, fn, nodeFn) + }) + + return fn(newChildren) +} + +func PreOrderTransform(node TreeNode, fn TransformFunc, nodeFn NodeTransformFunc) TreeNode { + newNode := fn(node) + + return nodeFn(newNode, func(n TreeNode) TreeNode { + return fn(n) + }) +} + +type TreeNodeVisitor interface { + Visit(TreeNode) (bool, any) + PreVisit(TreeNode) (bool, any) + VisitChildren(TreeNode) (bool, any) + PostVisit(TreeNode) (bool, any) +} + +// +//type BaseTreeVisitor struct{} +// +//func (v *BaseTreeVisitor) Visit(node TreeNode) (bool, any) { +// return true, node.Accept(v) +//} +// +//func (v *BaseTreeVisitor) VisitChildren(node TreeNode) (bool, any) { +// return O +//} +// +//func (v *BaseTreeVisitor) PreVisit(node TreeNode) (bool, any) { +// return true, nil +//} +// +//func (v *BaseTreeVisitor) PostVisit(node TreeNode) (bool, any) { +// return true, nil +//} + +type BaseTreeNode struct{} + +func (n *BaseTreeNode) String() string { + return fmt.Sprintf("%T", n) +} + +func (n *BaseTreeNode) Children() []TreeNode { + panic("implement me") +} + +func (n *BaseTreeNode) ShallowClone() TreeNode { + nn := *n + return &nn +} + +func (n *BaseTreeNode) Accept(v TreeNodeVisitor) (bool, interface{}) { + return v.Visit(n) +} + +func (n *BaseTreeNode) Apply(fn NodeFunc) (bool, any) { + return PreOrderApply(n, fn) +} + +// Transform applies the provided TransformFunc to copied node in post-order. +// NOTE: this should be implemented by the concrete node, otherwise it won't +// call concrete node's TransformChildren. +func (n *BaseTreeNode) TransformUp(fn TransformFunc) TreeNode { + newChildren := n.TransformChildren(func(node TreeNode) TreeNode { + return n.TransformUp(fn) + }) + + return fn(newChildren) +} + +func (n *BaseTreeNode) TransformDown(fn TransformFunc) TreeNode { + newNode := fn(n) + + return newNode.TransformChildren(func(node TreeNode) TreeNode { + return node.TransformDown(fn) + }) +} + +func (n *BaseTreeNode) TransformChildren(fn TransformFunc) TreeNode { + panic("implement me") +} + +func NewBaseTreeNode() *BaseTreeNode { + return &BaseTreeNode{} +} diff --git a/internal/engine/cost/plantree/tree_test.go b/internal/engine/cost/plantree/tree_test.go new file mode 100644 index 000000000..4f2eb2e2e --- /dev/null +++ b/internal/engine/cost/plantree/tree_test.go @@ -0,0 +1,318 @@ +package plantree_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + + pt "github.com/kwilteam/kwil-db/internal/engine/cost/plantree" +) + +type mockValueNode struct { + *pt.BaseTreeNode + + value any +} + +func (n *mockValueNode) Children() []pt.TreeNode { + return []pt.TreeNode{} +} + +func (n *mockValueNode) Accept(v pt.TreeNodeVisitor) (bool, any) { + return v.Visit(n) +} + +func (n *mockValueNode) TransformUp(fn pt.TransformFunc) pt.TreeNode { + newChildren := n.TransformChildren(func(node pt.TreeNode) pt.TreeNode { + return node.TransformUp(fn) + }) + + return fn(newChildren) +} + +func (n *mockValueNode) TransformDown(fn pt.TransformFunc) pt.TreeNode { + newNode := fn(n) + + return newNode.TransformChildren(func(node pt.TreeNode) pt.TreeNode { + return node.TransformDown(fn) + }) +} + +func (n *mockValueNode) TransformChildren(fn pt.TransformFunc) pt.TreeNode { + return &mockValueNode{ + BaseTreeNode: pt.NewBaseTreeNode(), + value: n.value, + } +} + +func (n *mockValueNode) String() string { + return fmt.Sprintf("%v", n.value) +} + +type mockBinaryTreeNode struct { + *pt.BaseTreeNode + + left pt.TreeNode + right pt.TreeNode +} + +func (n *mockBinaryTreeNode) Children() []pt.TreeNode { + return []pt.TreeNode{n.left, n.right} +} + +func (n *mockBinaryTreeNode) Accept(v pt.TreeNodeVisitor) (bool, any) { + return v.Visit(n) +} + +func (n *mockBinaryTreeNode) TransformUp(fn pt.TransformFunc) pt.TreeNode { + newChildren := n.TransformChildren(func(node pt.TreeNode) pt.TreeNode { + return node.TransformUp(fn) + }) + + return fn(newChildren) +} + +func (n *mockBinaryTreeNode) TransformDown(fn pt.TransformFunc) pt.TreeNode { + newNode := fn(n) + + return newNode.TransformChildren(func(node pt.TreeNode) pt.TreeNode { + return node.TransformDown(fn) + }) +} + +func (n *mockBinaryTreeNode) TransformChildren(fn pt.TransformFunc) pt.TreeNode { + return &mockBinaryTreeNode{ + BaseTreeNode: pt.NewBaseTreeNode(), + + left: fn(n.left), + right: fn(n.right), + } +} + +func (n *mockBinaryTreeNode) String() string { + return fmt.Sprintf("(%v, %v)", n.left, n.right) +} + +func mockLeftTree() *mockBinaryTreeNode { + // /\ + // /\ 4 + // /\ 3 + // 1 2 + return &mockBinaryTreeNode{ + BaseTreeNode: pt.NewBaseTreeNode(), + left: &mockBinaryTreeNode{ + BaseTreeNode: pt.NewBaseTreeNode(), + left: &mockBinaryTreeNode{ + BaseTreeNode: pt.NewBaseTreeNode(), + left: &mockValueNode{ + BaseTreeNode: pt.NewBaseTreeNode(), + value: 1, + }, + right: &mockValueNode{ + BaseTreeNode: pt.NewBaseTreeNode(), + value: 2, + }, + }, + right: &mockValueNode{ + BaseTreeNode: pt.NewBaseTreeNode(), + value: 3, + }, + }, + right: &mockValueNode{ + BaseTreeNode: pt.NewBaseTreeNode(), + value: 4, + }, + } +} + +func mockRightTree() *mockBinaryTreeNode { + // /\ + // 4 /\ + // 3 /\ + // 1 2 + return &mockBinaryTreeNode{ + BaseTreeNode: pt.NewBaseTreeNode(), + left: &mockValueNode{ + BaseTreeNode: pt.NewBaseTreeNode(), + value: 4, + }, + right: &mockBinaryTreeNode{ + BaseTreeNode: pt.NewBaseTreeNode(), + left: &mockValueNode{ + BaseTreeNode: pt.NewBaseTreeNode(), + value: 3, + }, + right: &mockBinaryTreeNode{ + BaseTreeNode: pt.NewBaseTreeNode(), + left: &mockValueNode{ + BaseTreeNode: pt.NewBaseTreeNode(), + value: 1, + }, + right: &mockValueNode{ + BaseTreeNode: pt.NewBaseTreeNode(), + value: 2, + }, + }, + }, + } +} + +func mockApplyFuncPreOrderCollect(n pt.TreeNode) []any { + collected := []any{} + + pt.PreOrderApply(n, func(n pt.TreeNode) (bool, any) { + if v, ok := n.(*mockValueNode); ok { + collected = append(collected, v.value) + return true, v.value + } + return true, n + }) + + return collected +} + +func mockApplyFuncPostOrderCollect(n pt.TreeNode) []any { + collected := []any{} + + pt.PostOrderApply(n, func(n pt.TreeNode) (bool, any) { + if v, ok := n.(*mockValueNode); ok { + collected = append(collected, v.value) + return true, v.value + } + return true, n + }) + + return collected +} + +func TestOrderApply_left_tree(t *testing.T) { + node := mockLeftTree() + + assert.Equal(t, []any{1, 2, 3, 4}, mockApplyFuncPreOrderCollect(node)) + assert.Equal(t, []any{1, 2, 3, 4}, mockApplyFuncPostOrderCollect(node)) +} + +func TestOrderApply_right_tree(t *testing.T) { + node := mockRightTree() + + assert.Equal(t, []any{4, 3, 1, 2}, mockApplyFuncPreOrderCollect(node)) + assert.Equal(t, []any{4, 3, 1, 2}, mockApplyFuncPostOrderCollect(node)) +} + +func mockTransform(node pt.TreeNode) pt.TreeNode { + return node.TransformUp(func(n pt.TreeNode) pt.TreeNode { + if v, ok := n.(*mockValueNode); ok { + return &mockValueNode{ + value: v.value.(int) * 2, + } + } + // otherwise, return the original node + return n + }) +} + +func TestTransform_left_tree(t *testing.T) { + node := mockLeftTree() + + transformed := mockTransform(node) + // new tree's nodes have been transformed + assert.Equal(t, []any{2, 4, 6, 8}, mockApplyFuncPreOrderCollect(transformed)) + // original tree's nodes are not changed + assert.Equal(t, []any{1, 2, 3, 4}, mockApplyFuncPreOrderCollect(node)) + + leftNode := node.left + leftTransformed := mockTransform(leftNode) + // new left tree's nodes have been transformed + assert.Equal(t, []any{2, 4, 6}, mockApplyFuncPreOrderCollect(leftTransformed)) + // original left tree's nodes are not changed + assert.Equal(t, []any{1, 2, 3}, mockApplyFuncPreOrderCollect(leftNode)) + // original parent tree's nodes are not changed + assert.Equal(t, []any{1, 2, 3, 4}, mockApplyFuncPreOrderCollect(node)) +} + +func mockNodeTransformFunc(node pt.TreeNode, transformFunc pt.TransformFunc) pt.TreeNode { + switch t := node.(type) { + case *mockValueNode: + return &mockValueNode{ + BaseTreeNode: pt.NewBaseTreeNode(), + value: t.value, + } + case *mockBinaryTreeNode: + return &mockBinaryTreeNode{ + BaseTreeNode: pt.NewBaseTreeNode(), + left: transformFunc(t.left), + right: transformFunc(t.right), + } + default: + panic("unknown node type") + } +} + +func TestTransform_left_tree_using_fn(t *testing.T) { + + node := mockLeftTree() + + mockTransFn := func(n pt.TreeNode) pt.TreeNode { + if v, ok := n.(*mockValueNode); ok { + return &mockValueNode{ + value: v.value.(int) * 2, + } + } + + // otherwise, return the original node + return n + } + + transformed := pt.PostOrderTransform(node, mockTransFn, mockNodeTransformFunc) + // or + //transformed := mockTransform(node) + + // new tree's nodes have been transformed + assert.Equal(t, []any{2, 4, 6, 8}, mockApplyFuncPreOrderCollect(transformed)) + // original tree's nodes are not changed + assert.Equal(t, []any{1, 2, 3, 4}, mockApplyFuncPreOrderCollect(node)) + + leftNode := node.left + leftTransformed := pt.PostOrderTransform(leftNode, mockTransFn, mockNodeTransformFunc) + // or + //leftTransformed := mockTransform(leftNode) + + // new left tree's nodes have been transformed + assert.Equal(t, []any{2, 4, 6}, mockApplyFuncPreOrderCollect(leftTransformed)) + // original left tree's nodes are not changed + assert.Equal(t, []any{1, 2, 3}, mockApplyFuncPreOrderCollect(leftNode)) + // original parent tree's nodes are not changed + assert.Equal(t, []any{1, 2, 3, 4}, mockApplyFuncPreOrderCollect(node)) +} + +type cloneB struct { + a int +} + +type cloneA struct { + b *cloneB +} + +func (c *cloneA) Clone1() *cloneA { + bb := *c.b + return &cloneA{b: &bb} +} + +func (c *cloneA) Clone2() *cloneA { + cc := *c + return &cc +} + +func TestClone(t *testing.T) { + a := &cloneA{b: &cloneB{a: 1}} + b := a.Clone1() + assert.Equal(t, 1, b.b.a) + + c := a.Clone2() + assert.Equal(t, 1, c.b.a) + + fmt.Printf("a: %p, a.b: %p\n", a, a.b) + fmt.Printf("b: %p, b.b: %p\n", b, b.b) + fmt.Printf("c: %p, c.b: %p\n", c, c.b) +} diff --git a/internal/engine/cost/query_planner/planner.go b/internal/engine/cost/query_planner/planner.go index d7a07f5c6..314ad001f 100644 --- a/internal/engine/cost/query_planner/planner.go +++ b/internal/engine/cost/query_planner/planner.go @@ -2,11 +2,13 @@ package query_planner import ( "fmt" + "slices" "strconv" "strings" - ds "github.com/kwilteam/kwil-db/internal/engine/cost/datasource" + ds "github.com/kwilteam/kwil-db/internal/engine/cost/datatypes" lp "github.com/kwilteam/kwil-db/internal/engine/cost/logical_plan" + pt "github.com/kwilteam/kwil-db/internal/engine/cost/plantree" "github.com/kwilteam/kwil-db/parse/sql/tree" ) @@ -21,6 +23,8 @@ func NewPlanner() *queryPlanner { return &queryPlanner{} } +// ToExpr converts a tree.Expression to a logical expression. +// TODO: use iterator or stack to traverse the tree, instead of recursive, to avoid stack overflow. func (q *queryPlanner) ToExpr(expr tree.Expression, schema *ds.Schema) lp.LogicalExpr { switch e := expr.(type) { case *tree.ExpressionLiteral: @@ -35,10 +39,8 @@ func (q *queryPlanner) ToExpr(expr tree.Expression, schema *ds.Schema) lp.Logica return &lp.LiteralIntExpr{Value: i} } case *tree.ExpressionColumn: - return &lp.ColumnExpr{ - Table: e.Table, - Name: e.Column, - } + // TODO: handle relation + return lp.ColumnUnqualified(e.Column) //case *tree.ExpressionFunction: case *tree.ExpressionUnary: switch e.Operator { @@ -147,7 +149,7 @@ func (q *queryPlanner) buildSelect(node *tree.SelectStmt, ctx *PlannerContext) l case tree.CompoundOperatorTypeExcept: plan = lp.Builder.From(left).Except(right).Build() default: - panic(fmt.Sprintf("unknown set operation %s", setOp)) + panic(fmt.Sprintf("unknown set operation %s", setOp.ToSQL())) } left = plan } @@ -237,25 +239,111 @@ func (q *queryPlanner) buildLimit(plan lp.LogicalPlan, node *tree.Limit) lp.Logi return lp.Builder.From(plan).Limit(offset, limit).Build() } +// buildSelectPlan builds a logical plan for a select statement. +// The order of building is: +// 1. from +// 2. where +// 3. group by(can use reference from select) +// 4. having(can use reference from select) +// 5. select +// 6. distinct +// 7. order by +// 8. limit func (q *queryPlanner) buildSelectPlan(node *tree.SelectCore, ctx *PlannerContext) lp.LogicalPlan { var plan lp.LogicalPlan + // from clause plan = q.buildFrom(node.From, ctx) - plan = q.buildFilter(plan, node.Where) // where + noFrom := false + if _, ok := plan.(*lp.NoFrom); ok { + noFrom = true + } + + // where clause + // after this step, we got a schema(maybe combined from different tables) to work with + sourcePlan := q.buildFilter(plan, node.Where, ctx) + + // try qualify expr, also expand `*` + projectExprs := q.prepareProjectionExprs(sourcePlan, node.Columns, noFrom, ctx) - // expand * in select list + // for having/group_by exprs + aliasMap := extractAliases(projectExprs) + projectedPlan := q.buildProjection(sourcePlan, projectExprs) + + combinedSchema := sourcePlan.Schema().Clone().Merge(projectedPlan.Schema()) + + ///////////// + // THIS IS WHERE I LEFT!!!!!!!! + var havingExpr lp.LogicalExpr if node.GroupBy != nil { - plan = b.buildAggregate(plan, node.GroupBy, node.Columns) // group by - plan = b.buildFilter(plan, node.GroupBy.Having) // having + havingExpr = q.buildHaving(node.GroupBy.Having, combinedSchema, aliasMap, ctx) } - // if orderBy , project for order + aggrExprs := slices.Clone(projectExprs) // shallow copy + if havingExpr != nil { + aggrExprs = append(aggrExprs, havingExpr) + } + aggrExprs = extractAggrExprs(aggrExprs) - plan = b.buildDistinct(plan, node.SelectType, node.Columns) // distinct + var groupByExprs []lp.LogicalExpr + if node.GroupBy != nil { + for _, gbExpr := range node.GroupBy.Expressions { + groupByExpr := q.ToExpr(gbExpr, combinedSchema) + + // avoid conflict + aliasMapClone := cloneAliases(aliasMap) + for _, f := range sourcePlan.Schema().Fields { + delete(aliasMapClone, f.Name) + } + + groupByExpr = resolveAlias(groupByExpr, aliasMapClone) + if err := ensureSchemaSatifiesExprs(combinedSchema, []lp.LogicalExpr{groupByExpr}); err != nil { + panic(err) + } + + groupByExprs = append(groupByExprs, groupByExpr) + } + } - plan = b.buildProjection(plan, orderBy, node.Columns) // project + var planAfterAggr lp.LogicalPlan + var projectedExpsAfterAggr []lp.LogicalExpr + + if len(groupByExprs) > 0 || len(aggrExprs) > 0 { + planAfterAggr, projectedExpsAfterAggr = q.buildAggregate( + sourcePlan, projectExprs, havingExpr, groupByExprs, aggrExprs) + } else { + if havingExpr != nil { + panic("having expression without group by") + } + } + + //////////// + + // another projection + plan = q.buildProjection(planAfterAggr, projectedExpsAfterAggr) + + // distinct + if node.SelectType == tree.SelectTypeDistinct { + plan = lp.Builder.From(plan).Distinct().Build() + } + + ////////// + + //if node.GroupBy != nil { + // plan = b.buildAggregate(plan, node.GroupBy, node.Columns) // group by + // plan = b.buildFilter(plan, node.GroupBy.Having) // having + //} + // + //// if orderBy , project for order + // + //plan = b.buildDistinct(plan, node.SelectType, node.Columns) // distinct + + //// TODO: handle group by,distinct, order by, limit + //newPlan := projectedPlan + //var projectExprAfterAggr []lp.LogicalExpr + //plan = q.buildProjection(newPlan, projectExprAfterAggr) // final project // done in VisitSelectStmt and VisitTableOrSubQuerySelect //plan = b.buildSort() // order by @@ -325,7 +413,12 @@ func (q *queryPlanner) buildTableSource(node *tree.RelationTable, ctx *PlannerCo //} //return nil - //return lp.Builder.From(node.Table).Build() + //tableRef, err := relationNameToTableRef(node.Name) + //if err != nil { + // panic(err) + //} + // + //return lp.Builder.From(node.Relation).Build() return nil } @@ -337,32 +430,123 @@ func (q *queryPlanner) buildFilter(plan lp.LogicalPlan, node tree.Expression, ct // TODO: handle parent schema expr := q.ToExpr(node, plan.Schema()) - seen := make(map[string]bool) - extractColumnsFromFilterExpr(expr, seen) + //seen := make(map[*lp.ColumnExpr]bool) + //extractColumnsFromFilterExpr(expr, seen) + //expr = qualifyExpr(expr, seen, plan.Schema()) + expr = qualifyExpr(expr, plan.Schema()) + return lp.Builder.From(plan).Select(expr).Build() +} - return lp.Builder.From(plan).Filter(expr).Build() +func (q *queryPlanner) buildProjection(plan lp.LogicalPlan, exprs []lp.LogicalExpr) lp.LogicalPlan { + return lp.Builder.From(plan).Select(exprs...).Build() } -// extractColumnsFromFilterExpr extracts the columns are references by the filter expression. -// It keeps track of the columns that have been seen in the 'seen' map. -func extractColumnsFromFilterExpr(expr lp.LogicalExpr, seen map[string]bool) { - switch e := expr.(type) { - case *lp.LiteralStringExpr: - case *lp.LiteralIntExpr: - case *lp.AliasExpr: - extractColumnsFromFilterExpr(e.Expr, seen) - case lp.UnaryExpr: - extractColumnsFromFilterExpr(e.E(), seen) - case lp.AggregateExpr: - extractColumnsFromFilterExpr(e.E(), seen) - case lp.BinaryExpr: - extractColumnsFromFilterExpr(e.L(), seen) - extractColumnsFromFilterExpr(e.R(), seen) - case *lp.ColumnExpr: - seen[e.Name] = true - //case *.ColumnIdxExpr: - // seen[input.Schema().Fields[e.Idx].Name] = true +func (q *queryPlanner) buildHaving(node tree.Expression, schema *ds.Schema, + aliasMap map[string]lp.LogicalExpr, ctx *PlannerContext) lp.LogicalExpr { + if node == nil { + return nil + } + + expr := q.ToExpr(node, schema) + expr = resolveAlias(expr, aliasMap) + expr = qualifyExpr(expr, schema) + return expr +} + +// buildAggregate builds a logical plan for an aggregate. +// A typical aggregate plan has group by, having, and aggregate expressions. +func (q *queryPlanner) buildAggregate(input lp.LogicalPlan, + projectedExprs []lp.LogicalExpr, havingExpr lp.LogicalExpr, + groupByExprs, aggrExprs []lp.LogicalExpr) (lp.LogicalPlan, []lp.LogicalExpr) { + plan := lp.Builder.From(input).Aggregate(groupByExprs, aggrExprs).Build() + if p, ok := plan.(*lp.AggregateOp); ok { + // rewrite projection to refer to columns that are output of aggregate plan. + plan = p + groupByExprs = p.GroupBy() + } else { + panic(fmt.Sprintf("unexpected plan type %T", plan)) + } + + // rewrite projection to refer to columns that are output of aggregate plan. + // + aggrProjectionExprs := slices.Clone(groupByExprs) + aggrProjectionExprs = append(aggrProjectionExprs, aggrExprs...) + // resolve the columns in projection + resolvedAggrProjectionExprs := make([]lp.LogicalExpr, len(aggrProjectionExprs)) + for i, expr := range aggrProjectionExprs { + e := expr.TransformUp(func(n pt.TreeNode) pt.TreeNode { + if c, ok := n.(*lp.ColumnExpr); ok { + field := c.Resolve(plan.Schema()) + return lp.ColumnFromDefToExpr(field.QualifiedColumn()) + } + return n + }) + + resolvedAggrProjectionExprs[i] = e.(lp.LogicalExpr) + } + // replace any expressions that are not a column with a column + // like `1+2` or `group by a+b`(a,b are alias) + var columnsAfterAggr []lp.LogicalExpr + for _, expr := range resolvedAggrProjectionExprs { + columnsAfterAggr = append(columnsAfterAggr, exprAsColumn(expr, plan)) + } + // + // rewrite projection + var projectedExprsAfterAggr []lp.LogicalExpr + for _, expr := range projectedExprs { + projectedExprsAfterAggr = append(projectedExprsAfterAggr, + rebaseExprs(expr, resolvedAggrProjectionExprs, plan)) + } + // make sure projection exprs can be resolved from columns + + if err := checkExprsProjectFromColumns(projectedExprsAfterAggr, + columnsAfterAggr); err != nil { + panic(fmt.Sprintf("build aggregation: %s", err)) + } + + if havingExpr != nil { + havingExpr = rebaseExprs(havingExpr, resolvedAggrProjectionExprs, plan) + if err := checkExprsProjectFromColumns( + []lp.LogicalExpr{havingExpr}, columnsAfterAggr); err != nil { + panic(fmt.Sprintf("build aggregation: %s", err)) + } + + plan = lp.Builder.From(plan).Select(havingExpr).Build() + } + + return plan, projectedExprsAfterAggr +} + +func (q *queryPlanner) prepareProjectionExprs(plan lp.LogicalPlan, node []tree.ResultColumn, noFrom bool, ctx *PlannerContext) []lp.LogicalExpr { + var exprs []lp.LogicalExpr + for _, col := range node { + exprs = append(exprs, q.projectColumnToExpr(col, plan, noFrom, ctx)...) + } + return exprs +} + +func (q *queryPlanner) projectColumnToExpr(col tree.ResultColumn, plan lp.LogicalPlan, noFrom bool, ctx *PlannerContext) []lp.LogicalExpr { + switch t := col.(type) { + case *tree.ResultColumnExpression: // single column + expr := q.ToExpr(t.Expression, plan.Schema()) + column := qualifyExpr(expr, nil, plan.Schema()) + if t.Alias != "" { // only add alias if it's not the same as column name + if c, ok := column.(*lp.ColumnExpr); ok { + if c.Name != t.Alias { + column = lp.Alias(column, t.Alias) + } + } + } + return []lp.LogicalExpr{column} + case *tree.ResultColumnStar: // expand * + if noFrom { + panic("cannot use * in select list without FROM clause") + } + + return expandStar(plan.Schema()) + case *tree.ResultColumnTable: // expand table.* + return expandQualifiedStar(plan.Schema(), t.TableName) default: - panic(fmt.Sprintf("unknown expression type %T", e)) + panic(fmt.Sprintf("unknown result column type %T", t)) } } diff --git a/internal/engine/cost/query_planner/planner_test.go b/internal/engine/cost/query_planner/planner_test.go new file mode 100644 index 000000000..4c6cc3714 --- /dev/null +++ b/internal/engine/cost/query_planner/planner_test.go @@ -0,0 +1,42 @@ +package query_planner + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/kwilteam/kwil-db/internal/engine/cost/logical_plan" + sqlparser "github.com/kwilteam/kwil-db/parse/sql" +) + +func Test_queryPlanner_ToPlan(t *testing.T) { + type args struct { + stmt string + } + tests := []struct { + name string + args args + want string + }{ + // TODO: Add test cases. + { + name: "simple select", + args: args{ + stmt: "SELECT * FROM users", + }, + want: "Scan: users; projection=[]", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ast, err := sqlparser.Parse(tt.args.stmt) + assert.NoError(t, err) + + q := &queryPlanner{} + got := q.ToPlan(ast) + explain := fmt.Sprintf(logical_plan.Format(got, 0)) + assert.Equal(t, tt.want, explain) + }) + } +} diff --git a/internal/engine/cost/query_planner/utils.go b/internal/engine/cost/query_planner/utils.go new file mode 100644 index 000000000..14ca24712 --- /dev/null +++ b/internal/engine/cost/query_planner/utils.go @@ -0,0 +1,246 @@ +package query_planner + +import ( + "fmt" + dt "github.com/kwilteam/kwil-db/internal/engine/cost/datatypes" + lp "github.com/kwilteam/kwil-db/internal/engine/cost/logical_plan" + pt "github.com/kwilteam/kwil-db/internal/engine/cost/plantree" + "slices" + "strings" +) + +func expandStar(schema *dt.Schema) []lp.LogicalExpr { + fmt.Println("=-------", schema) + // is there columns to skip? + var exprs []lp.LogicalExpr + for _, field := range schema.Fields { + // TODO: better way to get column expr? + exprs = append(exprs, &lp.ColumnExpr{Relation: field.Relation(), Name: field.Name}) + } + return exprs +} + +func expandQualifiedStar(schema *dt.Schema, table string) []lp.LogicalExpr { + panic("not implemented") +} + +// qualifyExpr returns a new expression qualified with the given relation. +// It won't change the original expression if it's not ColumnExpr. +// func qualifyExpr(expr lp.LogicalExpr, seen map[string] lp.LogicalExpr, schemas ...*dt.Schema) lp.LogicalExpr { +func qualifyExpr(expr lp.LogicalExpr, schemas ...*dt.Schema) lp.LogicalExpr { + c, ok := expr.(*lp.ColumnExpr) + if !ok { + return expr + } + + //// TODO: make all lp.LogicalExpr to implement pt.Node ? + //return c.TransformUp(func(n pt.Node) pt.Node { + // if c, ok := n.(*lp.ColumnExpr); ok { + // c.QualifyWithSchema(seen, schemas...) + // } + // return n + // + //}).(*lp.ColumnExpr) + + return c.QualifyWithSchemas(schemas...) +} + +// extractColumnsFromFilterExpr extracts the columns are references by the filter expression. +// It keeps track of the columns that have been seen in the 'seen' map. +// TODO: use visitor +func extractColumnsFromFilterExpr(expr lp.LogicalExpr, seen map[*lp.ColumnExpr]bool) { + switch e := expr.(type) { + case *lp.LiteralStringExpr: + case *lp.LiteralIntExpr: + case *lp.AliasExpr: + extractColumnsFromFilterExpr(e.Expr, seen) + case lp.UnaryExpr: + extractColumnsFromFilterExpr(e.E(), seen) + case lp.AggregateExpr: + extractColumnsFromFilterExpr(e.E(), seen) + case lp.BinaryExpr: + extractColumnsFromFilterExpr(e.L(), seen) + extractColumnsFromFilterExpr(e.R(), seen) + case *lp.ColumnExpr: + seen[e] = true + //case *.ColumnIdxExpr: + // seen[input.Schema().Fields[e.Idx].Name] = true + default: + panic(fmt.Sprintf("unknown expression type %T", e)) + } +} + +// extractAliases extracts the mapping of alias to its expression +func extractAliases(exprs []lp.LogicalExpr) map[string]lp.LogicalExpr { + aliases := make(map[string]lp.LogicalExpr) + for _, expr := range exprs { + if e, ok := expr.(*lp.AliasExpr); ok { + aliases[e.Alias] = e.Expr + } + } + return aliases +} + +func cloneAliases(aliases map[string]lp.LogicalExpr) map[string]lp.LogicalExpr { + clone := make(map[string]lp.LogicalExpr) + for k, v := range aliases { + clone[k] = v + } + return clone +} + +// resolveAliases resolves the expr to its un-aliased expression. +// It's used to resolve the alias in the select list to the actual expression. +func resolveAlias(expr lp.LogicalExpr, aliases map[string]lp.LogicalExpr) lp.LogicalExpr { + e := expr.TransformUp(func(n pt.TreeNode) pt.TreeNode { + if c, ok := n.(*lp.ColumnExpr); ok { + if e, ok := aliases[c.Name]; ok { + return e + } else { + return c + } + } + // otherwise, return the original node + return n + }) + + //_, e := pt.PostOrderApply(expr, func(n pt.TreeNode) (bool, any) { + // if e, ok := n.(*lp.ColumnExpr); ok { + // if e.Relation == nil { + // return true, nil + // } + // + // if aliasExpr, ok := aliases[e.Name]; ok { + // return true, aliasExpr + // } else { + // return true, e + // } + // } else { + // return true, n + // } + //}) + + return e.(lp.LogicalExpr) +} + +func extractAggrExprs(exprs []lp.LogicalExpr) []lp.LogicalExpr { + var aggrExprs []lp.LogicalExpr + for _, expr := range exprs { + if e, ok := expr.(lp.AggregateExpr); ok { + aggrExprs = append(aggrExprs, e) + } + } + return aggrExprs +} + +// allReferredColumns returns all the columns that are referenced by the expression. +func allReferredColumns(exprs []lp.LogicalExpr) []*lp.ColumnExpr { + var columns []*lp.ColumnExpr + for _, expr := range exprs { + pt.PreOrderApply(expr, func(n pt.TreeNode) (bool, any) { + if c, ok := n.(*lp.ColumnExpr); ok { + columns = append(columns, c) + } + return true, nil + }) + } + return columns +} + +// ensureSchemaSatifiesExprs ensures that the schema contains all the columns +// referenced by the expression. +func ensureSchemaSatifiesExprs(schema *dt.Schema, exprs []lp.LogicalExpr) error { + referredCols := allReferredColumns(exprs) + + for _, col := range referredCols { + if !schema.ContainsColumn(col.Relation, col.Name) { + return fmt.Errorf("column %s not found in schema", col.Name) + } + } + + return nil +} + +// rebaseExprs builds the expression on top of the base expressions. +// This is useful in the context of a query like: +// SELECT a + b < 1 ... GROUP BY a + b +func rebaseExprs(expr lp.LogicalExpr, baseExprs []lp.LogicalExpr, plan lp.LogicalPlan) lp.LogicalExpr { + return expr.TransformDown(func(n pt.TreeNode) pt.TreeNode { + contains := slices.ContainsFunc(baseExprs, func(e lp.LogicalExpr) bool { + // TODO: String() may not work + return e.String() == n.String() + }) + + if contains { + return exprAsColumn(n.(lp.LogicalExpr), plan) + } else { + return n + } + }).(lp.LogicalExpr) +} + +// checkExprsProjectFromColumns checks if the expression can be projected from the columns. +func checkExprsProjectFromColumns(exprs []lp.LogicalExpr, columns []lp.LogicalExpr) error { + for _, col := range columns { + if _, ok := col.(*lp.ColumnExpr); !ok { + return fmt.Errorf("expression %s is not a column", col.String()) + } + } + + colExprs := allReferredColumns(exprs) + for _, col := range colExprs { + if err := checkExprProjectFromColumns(col, columns); err != nil { + return err + } + } + return nil +} + +func checkExprProjectFromColumns(expr lp.LogicalExpr, columns []lp.LogicalExpr) error { + valid := slices.ContainsFunc(columns, func(c lp.LogicalExpr) bool { + return c.String() == expr.String() + }) + + if !valid { + return fmt.Errorf( + "expression %s cannot be resolved from available columns: %s", + expr.String(), columns) + } else { + return nil + } +} + +func exprAsColumn(expr lp.LogicalExpr, plan lp.LogicalPlan) *lp.ColumnExpr { + if c, ok := expr.(*lp.ColumnExpr); ok { + colDef := lp.ColumnFromExprToDef(c) + field := plan.Schema().FieldFromColumn(colDef) + return lp.ColumnFromDefToExpr(field.QualifiedColumn()) + } else { + // use the expression as the column name + // TODO: String() may not work + return lp.ColumnUnqualified(expr.String()) + } +} + +type TableRefName string + +func (t TableRefName) String() string { + return string(t) +} + +func (t TableRefName) Segments() []string { + return strings.Split(string(t), ".") +} + +func relationNameToTableRef(relationName string) (*dt.TableRef, error) { + tr := TableRefName(relationName) + segments := tr.Segments() + switch len(segments) { + case 1: + return &dt.TableRef{Table: segments[0]}, nil + case 2: + return &dt.TableRef{Schema: segments[0], Table: segments[1]}, nil + default: + return nil, fmt.Errorf("invalid relation name: %s", relationName) + } +} diff --git a/internal/engine/cost/virtual_plan/operator.go b/internal/engine/cost/virtual_plan/operator.go index 8e6db7275..a8fc1e7b5 100644 --- a/internal/engine/cost/virtual_plan/operator.go +++ b/internal/engine/cost/virtual_plan/operator.go @@ -2,6 +2,7 @@ package virtual_plan import ( "fmt" + "github.com/kwilteam/kwil-db/internal/engine/cost/datatypes" ds "github.com/kwilteam/kwil-db/internal/engine/cost/datasource" ) @@ -16,7 +17,7 @@ func (s *VScanOp) String() string { s.ds.Schema(), s.projection) } -func (s *VScanOp) Schema() *ds.Schema { +func (s *VScanOp) Schema() *datatypes.Schema { return s.ds.Schema().Select(s.projection...) } @@ -35,7 +36,7 @@ func VScan(ds ds.DataSource, projection ...string) VirtualPlan { type VProjectionOp struct { input VirtualPlan exprs []VirtualExpr - schema *ds.Schema + schema *datatypes.Schema } func (p *VProjectionOp) String() string { @@ -48,7 +49,7 @@ func (p *VProjectionOp) String() string { //return fmt.Sprintf("VProjection: %s", p.exprs) } -func (p *VProjectionOp) Schema() *ds.Schema { +func (p *VProjectionOp) Schema() *datatypes.Schema { return p.schema } @@ -80,7 +81,7 @@ func (p *VProjectionOp) Execute() *ds.Result { return ds.ResultFromStream(p.schema, out) } -func VProjection(input VirtualPlan, schema *ds.Schema, exprs ...VirtualExpr) VirtualPlan { +func VProjection(input VirtualPlan, schema *datatypes.Schema, exprs ...VirtualExpr) VirtualPlan { return &VProjectionOp{input: input, exprs: exprs, schema: schema} } @@ -94,7 +95,7 @@ func (s *VSelectionOp) String() string { //return fmt.Sprintf("VSelection: %s", s.expr) } -func (s *VSelectionOp) Schema() *ds.Schema { +func (s *VSelectionOp) Schema() *datatypes.Schema { return s.input.Schema() } diff --git a/internal/engine/cost/virtual_plan/plan.go b/internal/engine/cost/virtual_plan/plan.go index 95a86a534..58f0c9aaf 100644 --- a/internal/engine/cost/virtual_plan/plan.go +++ b/internal/engine/cost/virtual_plan/plan.go @@ -3,6 +3,7 @@ package virtual_plan import ( "bytes" "fmt" + "github.com/kwilteam/kwil-db/internal/engine/cost/datatypes" "github.com/kwilteam/kwil-db/internal/engine/cost/datasource" ) @@ -11,7 +12,7 @@ type VirtualPlan interface { fmt.Stringer // Schema returns the schema of the data that will be produced by this VirtualPlan. - Schema() *datasource.Schema + Schema() *datatypes.Schema Inputs() []VirtualPlan Execute() *datasource.Result diff --git a/internal/engine/cost/virtual_plan/planner.go b/internal/engine/cost/virtual_plan/planner.go index 2e518c26e..628170cb1 100644 --- a/internal/engine/cost/virtual_plan/planner.go +++ b/internal/engine/cost/virtual_plan/planner.go @@ -2,8 +2,8 @@ package virtual_plan import ( "fmt" + "github.com/kwilteam/kwil-db/internal/engine/cost/datatypes" - "github.com/kwilteam/kwil-db/internal/engine/cost/datasource" "github.com/kwilteam/kwil-db/internal/engine/cost/logical_plan" ) @@ -32,11 +32,11 @@ func (q *defaultVirtualPlanner) ToPlan(logicalPlan logical_plan.LogicalPlan) Vir for _, expr := range p.Exprs() { selectExprs = append(selectExprs, q.ToExpr(expr, p.Inputs()[0])) } - projectedFields := make([]datasource.Field, 0, len(selectExprs)) + projectedFields := make([]datatypes.Field, 0, len(selectExprs)) for _, expr := range p.Exprs() { projectedFields = append(projectedFields, expr.Resolve(p.Inputs()[0])) } - projectedSchema := datasource.NewSchema(projectedFields...) + projectedSchema := datatypes.NewSchema(projectedFields...) return VProjection(input, projectedSchema, selectExprs...) case *logical_plan.SelectionOp: input := q.ToPlan(p.Inputs()[0])