Skip to content

Commit

Permalink
Merge pull request #7 from mildred/authorizer
Browse files Browse the repository at this point in the history
Implement authorizer and add ability to run dynamic SQL statements
  • Loading branch information
codehz authored Aug 11, 2023
2 parents b3a6871 + 267bb79 commit 6f175f4
Show file tree
Hide file tree
Showing 2 changed files with 242 additions and 4 deletions.
8 changes: 7 additions & 1 deletion src/easy_sqlite3.nim
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import easy_sqlite3/[bindings,macros]
export macros

export raw, len, toOpenArray, SQLiteError, SQLiteBlob, Statement, Database, OpenFlag, enableSharedCache, initDatabase, exec, execM, changes, lastInsertRowid, `[]=`, reset, step, withColumnBlob, getParameterIndex, getColumnType, getColumn, unpack, `=destroy`
export raw, len, toOpenArray, SQLiteError, SQLiteBlob, Statement, Database,
SqliteDataType, OpenFlag, enableSharedCache, initDatabase, exec, execM,
changes, lastInsertRowid, `[]=`, reset, step, withColumnBlob,
getParameterIndex, getColumnType, getColumn, ColumnDef, columns, `[]`,
unpack, `=destroy`, newStatement, rows, setAuthorizer,
AuthorizerActionCode, AuthorizerRequest, AuthorizerResult, RawAuthorizer,
Authorizer
238 changes: 235 additions & 3 deletions src/easy_sqlite3/bindings.nim
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,102 @@ func compileTimeHash[T](original: static[T]): CachedHash[T] =
type Statement* = object
raw*: ptr RawStatement

type
AuthorizerResult* {.pure, size: sizeof(cint).} = enum
ok = 0,
deny = 1,
ignore = 2
AuthorizerActionCode* {.pure, size: sizeof(cint).} = enum
# action code # arg3 arg4
copy = 0, # No longer used
create_index = 1, # Index Name Table Name
create_table = 2, # Table Name NULL
create_temp_index = 3, # Index Name Table Name
create_temp_table = 4, # Table Name NULL
create_temp_trigger = 5, # Trigger Name Table Name
create_temp_view = 6, # View Name NULL
create_trigger = 7, # Trigger Name Table Name
create_view = 8, # View Name NULL
delete = 9, # Table Name NULL
drop_index = 10, # Index Name Table Name
drop_table = 11, # Table Name NULL
drop_temp_index = 12, # Index Name Table Name
drop_temp_table = 13, # Table Name NULL
drop_temp_trigger = 14, # Trigger Name Table Name
drop_temp_view = 15, # View Name NULL
drop_trigger = 16, # Trigger Name Table Name
drop_view = 17, # View Name NULL
insert = 18, # Table Name NULL
pragma = 19, # Pragma Name 1st arg or NULL
read = 20, # Table Name Column Name
select = 21, # NULL NULL
transaction = 22, # Operation NULL
update = 23, # Table Name Column Name
attach = 24, # Filename NULL
detach = 25, # Database Name NULL
alter_table = 26, # Database Name Table Name
reindex = 27, # Index Name NULL
analyze = 28, # Table Name NULL
create_vtable = 29, # Table Name Module Name
drop_vtable = 30, # Table Name Module Name
function = 31, # NULL Function Name
savepoint = 32, # Operation Savepoint Name
recursive = 33, # NULL NULL
AuthorizerRequest* = ref object
case action_code*: AuthorizerActionCode
of create_index, create_temp_index, drop_index, drop_temp_index:
index_name*: string
index_table_name*: string
of create_table, create_temp_table, delete, drop_table, drop_temp_table, insert, analyze:
table_name*: string
of create_temp_trigger, create_trigger, drop_temp_trigger, drop_trigger:
trigger_name*: string
trigger_table_name*: string
of create_temp_view, create_view, drop_temp_view, drop_view:
view_name*: string
of pragma:
pragma_name*: string
pragma_arg*: Option[string]
of read, update:
target_table_name*: string
column_name*: string
of select, recursive, copy:
discard
of transaction:
transaction_operation*: string
of attach:
filename*: string
of detach:
database_name*: string
of alter_table:
alter_database_name*: string
alter_table_name*: string
of reindex:
reindex_index_name*: string
of create_vtable, drop_vtable:
vtable_name*: string
module_name*: string
of function:
# no arg3
function_name*: string
of savepoint:
savepoint_operation*: string
savepoint_name*: string
SqliteRawAuthorizer* = proc(
userdata: pointer,
action_code: AuthorizerActionCode,
arg3, arg4, arg5, arg6: cstring): AuthorizerResult {.cdecl.}
RawAuthorizer* = proc(
action_code: AuthorizerActionCode,
arg3, arg4, arg5, arg6: Option[string]): AuthorizerResult
Authorizer* = proc(request: AuthorizerRequest): AuthorizerResult
WrapAuthorizer = object
authorizer: RawAuthorizer

type Database* = object
raw*: ptr RawDatabase
stmtcache: Table[CachedHash[string], ref Statement]
authorizer: ref WrapAuthorizer

type ResultCode* {.pure.} = enum
sr_ok = 0,
Expand Down Expand Up @@ -224,7 +317,7 @@ type SqliteDestroctor* = proc (p: pointer) {.cdecl.}
const StaticDestructor* = cast[SqliteDestroctor](0)
const TransientDestructor* = cast[SqliteDestroctor](-1)

type SqliteDateType* = enum
type SqliteDataType* = enum
dt_integer = 1,
dt_float = 2,
dt_text = 3,
Expand Down Expand Up @@ -374,6 +467,7 @@ proc sqlite3_prepare_v3*(db: ptr RawDatabase, sql: cstring, nbyte: int, flags: P
proc sqlite3_finalize*(st: ptr RawStatement): ResultCode {.sqlite3linkage.}
proc sqlite3_reset*(st: ptr RawStatement): ResultCode {.sqlite3linkage.}
proc sqlite3_step*(st: ptr RawStatement): ResultCode {.sqlite3linkage.}
proc sqlite3_set_authorizer*(db: ptr RawDatabase, auth: SqliteRawAuthorizer, userdata: pointer): ResultCode {.sqlite3linkage.}
proc sqlite3_bind_parameter_index*(st: ptr RawStatement, name: cstring): int {.sqlite3linkage.}
proc sqlite3_bind_blob64*(st: ptr RawStatement, idx: int, buffer: pointer, len: int, free: SqliteDestroctor): ResultCode {.sqlite3linkage.}
proc sqlite3_bind_double*(st: ptr RawStatement, idx: int, value: float64): ResultCode {.sqlite3linkage.}
Expand All @@ -385,7 +479,9 @@ proc sqlite3_bind_pointer*(st: ptr RawStatement, idx: int, val: pointer, name: c
proc sqlite3_bind_zeroblob64*(st: ptr RawStatement, idx: int, len: int): ResultCode {.sqlite3linkage.}
proc sqlite3_changes*(st: ptr RawDatabase): int {.sqlite3linkage.}
proc sqlite3_last_insert_rowid*(st: ptr RawDatabase): int {.sqlite3linkage.}
proc sqlite3_column_type*(st: ptr RawStatement, idx: int): SqliteDateType {.sqlite3linkage.}
proc sqlite3_column_count*(st: ptr RawStatement): int {.sqlite3linkage.}
proc sqlite3_column_type*(st: ptr RawStatement, idx: int): SqliteDataType {.sqlite3linkage.}
proc sqlite3_column_name*(st: ptr RawStatement, idx: int): cstring {.sqlite3linkage.}
proc sqlite3_column_blob*(st: ptr RawStatement, idx: int): pointer {.sqlite3linkage.}
proc sqlite3_column_bytes*(st: ptr RawStatement, idx: int): int {.sqlite3linkage.}
proc sqlite3_column_double*(st: ptr RawStatement, idx: int): float64 {.sqlite3linkage.}
Expand Down Expand Up @@ -442,6 +538,107 @@ proc initDatabase*(
sqliteCheck sqlite3_open_v2(filename, addr result.raw, flags, vfs)
result.stmtcache = initTable[CachedHash[string], ref Statement]()

proc toS(s: cstring): Option[string] =
if s == nil:
result = none(string)
else:
result = some($s)

proc setAuthorizer*(db: var Database, callback: RawAuthorizer = nil) =
let userdata: ref WrapAuthorizer = new(WrapAuthorizer)
userdata.authorizer = callback

proc raw_callback(
userdata: pointer,
action_code: AuthorizerActionCode,
arg3, arg4, arg5, arg6: cstring): AuthorizerResult {.cdecl.} =
let callback = cast[ref WrapAuthorizer](userdata).authorizer
callback(action_code, arg3.toS(), arg4.toS(), arg5.toS(), arg6.toS())

var res: ResultCode
if callback == nil:
res = db.raw.sqlite3_set_authorizer(nil, nil)
else:
res = db.raw.sqlite3_set_authorizer(raw_callback, cast[pointer](userdata))
db.authorizer = userdata
if res != ResultCode.sr_ok:
raise newSQLiteError res

proc setAuthorizer*(db: var Database, callback: Authorizer = nil) =
var raw_callback: RawAuthorizer = nil
if callback != nil:
raw_callback = proc(code: AuthorizerActionCode, arg3, arg4, arg5, arg6: Option[string]): AuthorizerResult =
var req: AuthorizerRequest
case code
of create_index, create_temp_index, drop_index, drop_temp_index:
req = AuthorizerRequest(
action_code: code,
index_name: arg3.get,
index_table_name: arg4.get)
of create_table, create_temp_table, delete, drop_table, drop_temp_table, insert, analyze:
req = AuthorizerRequest(
action_code: code,
table_name: arg3.get)
of create_temp_trigger, create_trigger, drop_temp_trigger, drop_trigger:
req = AuthorizerRequest(
action_code: code,
trigger_name: arg3.get,
trigger_table_name: arg4.get)
of create_temp_view, create_view, drop_temp_view, drop_view:
req = AuthorizerRequest(
action_code: code,
view_name: arg3.get)
of pragma:
req = AuthorizerRequest(
action_code: code,
pragma_name: arg3.get,
pragma_arg: arg4)
of read, update:
req = AuthorizerRequest(
action_code: code,
target_table_name: arg3.get,
column_name: arg4.get)
of select, recursive, copy:
req = AuthorizerRequest(action_code: code)
of transaction:
req = AuthorizerRequest(
action_code: code,
transaction_operation: arg3.get)
of attach:
req = AuthorizerRequest(
action_code: code,
filename: arg3.get)
of detach:
req = AuthorizerRequest(
action_code: code,
database_name: arg3.get)
of alter_table:
req = AuthorizerRequest(
action_code: code,
alter_database_name: arg3.get,
alter_table_name: arg4.get)
of reindex:
req = AuthorizerRequest(
action_code: code,
reindex_index_name: arg3.get)
of create_vtable, drop_vtable:
req = AuthorizerRequest(
action_code: code,
vtable_name: arg3.get,
module_name: arg4.get)
of function:
req = AuthorizerRequest(
action_code: code,
# no arg3
function_name: arg4.get)
of savepoint:
req = AuthorizerRequest(
action_code: code,
savepoint_operation: arg3.get,
savepoint_name: arg4.get)
return callback(req)
db.setAuthorizer(raw_callback)

proc changes*(st: var Database): int =
sqlite3_changes st.raw

Expand Down Expand Up @@ -497,6 +694,9 @@ proc `[]=`*[T](st: ref Statement, idx: int, val: Option[T]) =
else:
st[idx] = val.get

proc `[]=`*[T](st: ref Statement, name: string, value: T) =
st[st.getParameterIndex(name)] = value

proc reset*(st: ref Statement) =
st.raw.sqliteCheck sqlite3_reset(st.raw)

Expand All @@ -515,7 +715,7 @@ proc withColumnBlob*(st: ref Statement, idx: int, recv: proc(vm: openarray[byte]
let l = sqlite3_column_bytes(st.raw, idx)
recv(cast[ptr UncheckedArray[byte]](p).toOpenArray(0, l))

proc getColumnType*(st: ref Statement, idx: int): SqliteDateType =
proc getColumnType*(st: ref Statement, idx: int): SqliteDataType =
sqlite3_column_type(st.raw, idx)

proc getColumn*(st: ref Statement, idx: int, T: typedesc[seq[byte]]): seq[byte] =
Expand All @@ -542,6 +742,31 @@ proc getColumn*[T](st: ref Statement, idx: int, _: typedesc[Option[T]]): Option[
else:
some(st.getColumn(idx, T))

type ColumnDef* = object
st*: ref Statement
idx*: int
data_type*: SqliteDataType
name*: string

proc columns*(st: ref Statement): seq[ref ColumnDef] =
result = @[]
var idx = 0
let count = sqlite3_column_count(st.raw)
while idx < count:
let col = new(ColumnDef)
col.st = st
col.idx = idx
col.data_type = sqlite3_column_type(st.raw, idx)
col.name = $sqlite3_column_name(st.raw, idx)
result.add(col)
idx += 1

proc `[]`*(st: ref Statement, idx: int): ref ColumnDef =
result = st.columns[idx]

proc `[]`*[T](col: ref ColumnDef, t: typedesc[T]): T =
result = col.st.getColumn(col.idx, t)

proc unpack*[T: tuple](st: ref Statement, _: typedesc[T]): T =
var idx = 0
for value in result.fields:
Expand All @@ -568,3 +793,10 @@ proc execM*(db: var Database, sqls: varargs[string]) {.discardable.} =
except CatchableError:
discard db.exec "ROLLBACK"
raise getCurrentException()

iterator rows*(st: ref Statement): seq[ref ColumnDef] =
try:
while st.step():
yield st.columns()
finally:
st.reset()

0 comments on commit 6f175f4

Please sign in to comment.