Skip to content

Commit 2526fc8

Browse files
committed
Transitive closure virtual table.
1 parent d737620 commit 2526fc8

File tree

4 files changed

+430
-3
lines changed

4 files changed

+430
-3
lines changed

ext/closure/closure.go

Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
// Package closure provides a transitive closure virtual table.
2+
//
3+
// The "transitive_closure" virtual table finds the transitive closure of
4+
// a parent/child relationship in a real table.
5+
//
6+
// https://sqlite.org/src/doc/tip/ext/misc/closure.c
7+
package closure
8+
9+
import (
10+
"fmt"
11+
"math"
12+
13+
"github.com/ncruces/go-sqlite3"
14+
"github.com/ncruces/go-sqlite3/internal/util"
15+
"github.com/ncruces/go-sqlite3/util/vtabutil"
16+
)
17+
18+
const (
19+
_COL_ID = 0
20+
_COL_DEPTH = 1
21+
_COL_ROOT = 2
22+
_COL_TABLENAME = 3
23+
_COL_IDCOLUMN = 4
24+
_COL_PARENTCOLUMN = 5
25+
)
26+
27+
func Register(db *sqlite3.Conn) error {
28+
return sqlite3.CreateModule(db, "transitive_closure", nil,
29+
func(db *sqlite3.Conn, _, _, _ string, arg ...string) (*closure, error) {
30+
var (
31+
table string
32+
column string
33+
parent string
34+
35+
done = util.Set[string]{}
36+
)
37+
38+
for _, arg := range arg {
39+
key, val := vtabutil.NamedArg(arg)
40+
if done.Contains(key) {
41+
return nil, fmt.Errorf("transitive_closure: more than one %q parameter", key)
42+
}
43+
switch key {
44+
case "tablename":
45+
table = vtabutil.Unquote(val)
46+
case "idcolumn":
47+
column = vtabutil.Unquote(val)
48+
case "parentcolumn":
49+
parent = vtabutil.Unquote(val)
50+
default:
51+
return nil, fmt.Errorf("transitive_closure: unknown %q parameter", key)
52+
}
53+
done.Add(key)
54+
}
55+
56+
err := db.DeclareVTab(`CREATE TABLE x(id,depth,root HIDDEN,tablename HIDDEN,idcolumn HIDDEN,parentcolumn HIDDEN)`)
57+
if err != nil {
58+
return nil, err
59+
}
60+
return &closure{
61+
db: db,
62+
table: table,
63+
column: column,
64+
parent: parent,
65+
}, nil
66+
})
67+
}
68+
69+
type closure struct {
70+
db *sqlite3.Conn
71+
table string
72+
column string
73+
parent string
74+
}
75+
76+
func (c *closure) Destroy() error { return nil }
77+
78+
func (c *closure) BestIndex(idx *sqlite3.IndexInfo) error {
79+
posi := 1
80+
plan := 0
81+
cost := 10000000.0
82+
83+
for i, cst := range idx.Constraint {
84+
if !cst.Usable {
85+
continue
86+
}
87+
if plan&1 == 0 && cst.Column == _COL_ROOT {
88+
switch cst.Op {
89+
case sqlite3.INDEX_CONSTRAINT_EQ:
90+
plan |= 1
91+
cost /= 100
92+
idx.ConstraintUsage[i].ArgvIndex = 1
93+
idx.ConstraintUsage[i].Omit = true
94+
}
95+
continue
96+
}
97+
if plan&0xf0 == 0 && cst.Column == _COL_DEPTH {
98+
switch cst.Op {
99+
case sqlite3.INDEX_CONSTRAINT_LT, sqlite3.INDEX_CONSTRAINT_LE, sqlite3.INDEX_CONSTRAINT_EQ:
100+
plan |= posi << 4
101+
cost /= 5
102+
posi += 1
103+
idx.ConstraintUsage[i].ArgvIndex = posi
104+
if cst.Op == sqlite3.INDEX_CONSTRAINT_LT {
105+
plan |= 2
106+
}
107+
}
108+
continue
109+
}
110+
if plan&0xf00 == 0 && cst.Column == _COL_TABLENAME {
111+
switch cst.Op {
112+
case sqlite3.INDEX_CONSTRAINT_EQ:
113+
plan |= posi << 8
114+
cost /= 5
115+
posi += 1
116+
idx.ConstraintUsage[i].ArgvIndex = posi
117+
idx.ConstraintUsage[i].Omit = true
118+
}
119+
continue
120+
}
121+
if plan&0xf000 == 0 && cst.Column == _COL_IDCOLUMN {
122+
switch cst.Op {
123+
case sqlite3.INDEX_CONSTRAINT_EQ:
124+
plan |= posi << 12
125+
posi += 1
126+
idx.ConstraintUsage[i].ArgvIndex = posi
127+
idx.ConstraintUsage[i].Omit = true
128+
}
129+
continue
130+
}
131+
if plan&0xf0000 == 0 && cst.Column == _COL_PARENTCOLUMN {
132+
switch cst.Op {
133+
case sqlite3.INDEX_CONSTRAINT_EQ:
134+
plan |= posi << 16
135+
posi += 1
136+
idx.ConstraintUsage[i].ArgvIndex = posi
137+
idx.ConstraintUsage[i].Omit = true
138+
}
139+
continue
140+
}
141+
}
142+
143+
if c.table == "" && plan&0xf00 == 0 ||
144+
c.column == "" && plan&0xf000 == 0 ||
145+
c.parent == "" && plan&0xf0000 == 0 {
146+
plan = 0
147+
}
148+
if plan&1 == 0 {
149+
plan = 0
150+
cost *= 1e30
151+
for i := range idx.Constraint {
152+
idx.ConstraintUsage[i].ArgvIndex = 0
153+
}
154+
}
155+
156+
idx.EstimatedCost = cost
157+
idx.IdxNum = plan
158+
return nil
159+
}
160+
161+
func (c *closure) Open() (sqlite3.VTabCursor, error) {
162+
return &cursor{closure: c}, nil
163+
}
164+
165+
type cursor struct {
166+
*closure
167+
nodes []node
168+
}
169+
170+
type node struct {
171+
id int64
172+
depth int
173+
}
174+
175+
func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
176+
if idxNum&1 == 0 {
177+
return nil
178+
}
179+
180+
root := arg[0].Int64()
181+
maxDepth := math.MaxInt
182+
if idxNum&0xf0 != 0 {
183+
maxDepth = arg[(idxNum>>4)&0xf].Int()
184+
if idxNum&2 != 0 {
185+
maxDepth -= 1
186+
}
187+
}
188+
table := c.table
189+
if idxNum&0xf00 != 0 {
190+
table = arg[(idxNum>>8)&0xf].Text()
191+
}
192+
column := c.column
193+
if idxNum&0xf000 != 0 {
194+
column = arg[(idxNum>>12)&0xf].Text()
195+
}
196+
parent := c.parent
197+
if idxNum&0xf0000 != 0 {
198+
parent = arg[(idxNum>>16)&0xf].Text()
199+
}
200+
201+
sql := fmt.Sprintf(
202+
`SELECT %[1]s.%[2]s FROM %[1]s WHERE %[1]s.%[3]s=?`,
203+
sqlite3.QuoteIdentifier(table),
204+
sqlite3.QuoteIdentifier(column),
205+
sqlite3.QuoteIdentifier(parent),
206+
)
207+
stmt, _, err := c.db.Prepare(sql)
208+
if err != nil {
209+
return err
210+
}
211+
defer stmt.Close()
212+
213+
c.nodes = []node{{root, 0}}
214+
set := util.Set[int64]{}
215+
set.Add(root)
216+
for i := 0; i < len(c.nodes); i++ {
217+
curr := c.nodes[i]
218+
if curr.depth >= maxDepth {
219+
continue
220+
}
221+
stmt.BindInt64(1, curr.id)
222+
for stmt.Step() {
223+
if stmt.ColumnType(0) == sqlite3.INTEGER {
224+
next := stmt.ColumnInt64(0)
225+
if !set.Contains(next) {
226+
set.Add(next)
227+
c.nodes = append(c.nodes, node{next, curr.depth + 1})
228+
}
229+
}
230+
}
231+
stmt.Reset()
232+
}
233+
return nil
234+
}
235+
236+
func (c *cursor) Column(ctx sqlite3.Context, n int) error {
237+
switch n {
238+
case _COL_ID:
239+
ctx.ResultInt64(c.nodes[0].id)
240+
case _COL_DEPTH:
241+
ctx.ResultInt(c.nodes[0].depth)
242+
case _COL_TABLENAME:
243+
ctx.ResultText(c.table)
244+
case _COL_IDCOLUMN:
245+
ctx.ResultText(c.column)
246+
case _COL_PARENTCOLUMN:
247+
ctx.ResultText(c.parent)
248+
}
249+
return nil
250+
}
251+
252+
func (c *cursor) Next() error {
253+
c.nodes = c.nodes[1:]
254+
return nil
255+
}
256+
257+
func (c *cursor) EOF() bool {
258+
return len(c.nodes) == 0
259+
}
260+
261+
func (c *cursor) RowID() (int64, error) {
262+
return c.nodes[0].id, nil
263+
}

0 commit comments

Comments
 (0)