Skip to content

Commit

Permalink
support psql
Browse files Browse the repository at this point in the history
  • Loading branch information
zyxkad committed Mar 22, 2024
1 parent 716b83f commit fdeb5b7
Show file tree
Hide file tree
Showing 3 changed files with 217 additions and 18 deletions.
226 changes: 210 additions & 16 deletions database/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ import (
)

type SqlDB struct {
db *sql.DB
driverName string
db *sql.DB

jtiStmts struct {
get *sql.Stmt
add *sql.Stmt
remove *sql.Stmt
clean *sql.Stmt
}

fileRecordStmts struct {
Expand Down Expand Up @@ -71,7 +73,8 @@ func NewSqlDB(driverName string, dataSourceName string) (db *SqlDB, err error) {
ddb.SetMaxIdleConns(16)

db = &SqlDB{
db: ddb,
driverName: driverName,
db: ddb,
}

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
Expand Down Expand Up @@ -108,6 +111,34 @@ func (db *SqlDB) Cleanup() (err error) {
}

func (db *SqlDB) setupJTI(ctx context.Context) (err error) {
switch db.driverName {
case "sqlite", "mysql":
err = db.setupJTIQuestionMark(ctx)
case "postgres":
err = db.setupJTIDollarMark(ctx)
default:
panic("Unknown sql drive " + db.driverName)
}
if err != nil {
return
}

db.jtiCleaner = time.NewTimer(time.Minute * 10)
go func(timer *time.Timer, cleanStmt *sql.Stmt) {
defer cleanStmt.Close()
for range timer.C {
ctx, cancel := context.WithTimeout(ctx, time.Second*15)
_, err := cleanStmt.ExecContext(ctx)
cancel()
if err != nil {
log.Errorf("Error when cleaning expired tokens: %v", err)
}
}
}(db.jtiCleaner, db.jtiStmts.clean)
return
}

func (db *SqlDB) setupJTIQuestionMark(ctx context.Context) (err error) {
const tableName = "`token_id`"

const createTable = "CREATE TABLE IF NOT EXISTS " + tableName + " (" +
Expand Down Expand Up @@ -140,23 +171,48 @@ func (db *SqlDB) setupJTI(ctx context.Context) (err error) {

const cleanDeleteCmd = "DELETE FROM " + tableName +
" WHERE `expire` < CURRENT_TIMESTAMP"
var cleanStmt *sql.Stmt
if cleanStmt, err = db.db.PrepareContext(ctx, cleanDeleteCmd); err != nil {
if db.jtiStmts.clean, err = db.db.PrepareContext(ctx, cleanDeleteCmd); err != nil {
return
}
return
}

db.jtiCleaner = time.NewTimer(time.Minute * 10)
go func(timer *time.Timer, cleanStmt *sql.Stmt) {
defer cleanStmt.Close()
for range timer.C {
ctx, cancel := context.WithTimeout(ctx, time.Second*15)
_, err := cleanStmt.ExecContext(ctx)
cancel()
if err != nil {
log.Errorf("Error when cleaning expired tokens: %v", err)
}
}
}(db.jtiCleaner, cleanStmt)
func (db *SqlDB) setupJTIDollarMark(ctx context.Context) (err error) {
const tableName = "token_id"

const createTable = "CREATE TABLE IF NOT EXISTS " + tableName + " (" +
" id VARCHAR(127) NOT NULL," +
" expire TIMESTAMP NOT NULL," +
" PRIMARY KEY (id)" +
")"
if _, err = db.db.ExecContext(ctx, createTable); err != nil {
return
}

const getSelectCmd = "SELECT 1 FROM " + tableName +
" WHERE id=$1 AND expire > CURRENT_TIMESTAMP"
if db.jtiStmts.get, err = db.db.PrepareContext(ctx, getSelectCmd); err != nil {
return
}

const addInsertCmd = "INSERT INTO " + tableName +
" (id,expire) VALUES" +
" ($1,$2)"
if db.jtiStmts.add, err = db.db.PrepareContext(ctx, addInsertCmd); err != nil {
return
}

const removeDeleteCmd = "DELETE FROM " + tableName +
" WHERE id=$1"
if db.jtiStmts.remove, err = db.db.PrepareContext(ctx, removeDeleteCmd); err != nil {
return
}

const cleanDeleteCmd = "DELETE FROM " + tableName +
" WHERE expire < CURRENT_TIMESTAMP"
if db.jtiStmts.clean, err = db.db.PrepareContext(ctx, cleanDeleteCmd); err != nil {
return
}
return
}

Expand Down Expand Up @@ -198,6 +254,17 @@ func (db *SqlDB) RemoveJTI(jti string) (err error) {
}

func (db *SqlDB) setupFileRecords(ctx context.Context) (err error) {
switch db.driverName {
case "sqlite", "mysql":
return db.setupFileRecordsQuestionMark(ctx)
case "postgres":
return db.setupFileRecordsDollarMark(ctx)
default:
panic("Unknown sql drive " + db.driverName)
}
}

func (db *SqlDB) setupFileRecordsQuestionMark(ctx context.Context) (err error) {
const tableName = "`file_records`"

const createTable = "CREATE TABLE IF NOT EXISTS " + tableName + " (" +
Expand Down Expand Up @@ -248,6 +315,57 @@ func (db *SqlDB) setupFileRecords(ctx context.Context) (err error) {
return err
}

func (db *SqlDB) setupFileRecordsDollarMark(ctx context.Context) (err error) {
const tableName = "file_records"

const createTable = "CREATE TABLE IF NOT EXISTS " + tableName + " (" +
" path VARCHAR(255) NOT NULL," +
" hash VARCHAR(255) NOT NULL," +
" size INTEGER NOT NULL," +
" PRIMARY KEY (path)" +
")"
if _, err = db.db.ExecContext(ctx, createTable); err != nil {
return
}

const getSelectCmd = "SELECT hash,size FROM " + tableName +
" WHERE path=$1"
if db.fileRecordStmts.get, err = db.db.PrepareContext(ctx, getSelectCmd); err != nil {
return
}

const hasSelectCmd = "SELECT 1 FROM " + tableName +
" WHERE path=$1"
if db.fileRecordStmts.has, err = db.db.PrepareContext(ctx, hasSelectCmd); err != nil {
return
}

const setInsertCmd = "INSERT INTO " + tableName +
" (path,hash,size) VALUES" +
" ($1,$2,$3)"
const setUpdateCmd = "UPDATE " + tableName + " SET" +
" hash=$1, size=$2" +
" WHERE path=$3"
if db.fileRecordStmts.setInsert, err = db.db.PrepareContext(ctx, setInsertCmd); err != nil {
return
}
if db.fileRecordStmts.setUpdate, err = db.db.PrepareContext(ctx, setUpdateCmd); err != nil {
return
}

const removeDeleteCmd = "DELETE FROM " + tableName +
" WHERE path=$1"
if db.fileRecordStmts.remove, err = db.db.PrepareContext(ctx, removeDeleteCmd); err != nil {
return
}

const forEachSelectCmd = "SELECT path,hash,size FROM " + tableName
if db.fileRecordStmts.forEach, err = db.db.PrepareContext(ctx, forEachSelectCmd); err != nil {
return
}
return err
}

func (db *SqlDB) GetFileRecord(path string) (rec *FileRecord, err error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
Expand Down Expand Up @@ -332,6 +450,17 @@ func (db *SqlDB) ForEachFileRecord(cb func(*FileRecord) error) (err error) {
}

func (db *SqlDB) setupSubscribe(ctx context.Context) (err error) {
switch db.driverName {
case "sqlite", "mysql":
return db.setupSubscribeQuestionMark(ctx)
case "postgres":
return db.setupSubscribeDollarMark(ctx)
default:
panic("Unknown sql drive " + db.driverName)
}
}

func (db *SqlDB) setupSubscribeQuestionMark(ctx context.Context) (err error) {
const tableName = "`subscribes`"

const createTable = "CREATE TABLE IF NOT EXISTS " + tableName + " (" +
Expand Down Expand Up @@ -396,6 +525,71 @@ func (db *SqlDB) setupSubscribe(ctx context.Context) (err error) {
return err
}

func (db *SqlDB) setupSubscribeDollarMark(ctx context.Context) (err error) {
const tableName = "subscribes"

const createTable = "CREATE TABLE IF NOT EXISTS " + tableName + " (" +
` "user" VARCHAR(127) NOT NULL,` +
" client VARCHAR(127) NOT NULL," +
" endpoint VARCHAR(255) NOT NULL," +
" keys VARCHAR(255) NOT NULL," +
" scopes INTEGER NOT NULL," +
` PRIMARY KEY ("user",client)` +
")"
if _, err = db.db.ExecContext(ctx, createTable); err != nil {
return
}

const getSelectCmd = "SELECT endpoint,keys,scopes FROM " + tableName +
` WHERE "user"=$1 AND client=$2`
if db.subscribeStmts.get, err = db.db.PrepareContext(ctx, getSelectCmd); err != nil {
return
}

const hasSelectCmd = "SELECT 1 FROM " + tableName +
` WHERE "user"=$1 AND client=$2`
if db.subscribeStmts.has, err = db.db.PrepareContext(ctx, hasSelectCmd); err != nil {
return
}

const setInsertCmd = "INSERT INTO " + tableName +
` ("user",client,endpoint,keys,scopes) VALUES` +
" ($1,$2,$3,$4,$5)"
const setUpdateCmd = "UPDATE " + tableName + " SET" +
" endpoint=$1, keys=$2, scopes=$3" +
` WHERE "user"=$4 AND client=$5`
const setUpdateScopesOnlyCmd = "UPDATE " + tableName + " SET" +
" scopes=$1" +
` WHERE "user"=$2 AND client=$3`
if db.subscribeStmts.setInsert, err = db.db.PrepareContext(ctx, setInsertCmd); err != nil {
return
}
if db.subscribeStmts.setUpdate, err = db.db.PrepareContext(ctx, setUpdateCmd); err != nil {
return
}
if db.subscribeStmts.setUpdateScopesOnly, err = db.db.PrepareContext(ctx, setUpdateScopesOnlyCmd); err != nil {
return
}

const removeDeleteCmd = "DELETE FROM " + tableName +
` WHERE "user"=$1 AND client=$2`
if db.subscribeStmts.remove, err = db.db.PrepareContext(ctx, removeDeleteCmd); err != nil {
return
}

const removeUserDeleteCmd = "DELETE FROM " + tableName +
` WHERE "user"=$1`
if db.subscribeStmts.removeUser, err = db.db.PrepareContext(ctx, removeUserDeleteCmd); err != nil {
return
}

const forEachSelectCmd = `SELECT "user",client,endpoint,keys,scopes FROM ` + tableName
if db.subscribeStmts.forEach, err = db.db.PrepareContext(ctx, forEachSelectCmd); err != nil {
return
}
return err
}

func (db *SqlDB) GetSubscribe(user string, client string) (rec *SubscribeRecord, err error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
Expand Down
2 changes: 1 addition & 1 deletion database/sql_drives.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@ package database
import (
_ "github.com/glebarez/go-sqlite" // sqlite
_ "github.com/go-sql-driver/mysql" // mysql
// _ "github.com/lib/pq" // postgres // TODO: query with $1, $2 ... format
_ "github.com/lib/pq" // postgres
)
7 changes: 6 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,12 +279,17 @@ func (r *Runner) DoSignals(cancel context.CancelFunc) int {
dumpFile.Close()
dumpCommand = (string)(bytes.TrimSpace(buf[:n]))
}
pcmd := pprof.Lookup(dumpCommand)
if pcmd == nil {
log.Errorf("No pprof command is named %q", dumpCommand)
continue
}
name := fmt.Sprintf(time.Now().Format("dump-%s-20060102-150405.txt"), dumpCommand)
log.Infof("Creating goroutine dump file at %s", name)
if fd, err := os.Create(name); err != nil {
log.Infof("Cannot create dump file: %v", err)
} else {
err := pprof.Lookup(dumpCommand).WriteTo(fd, 1)
err := pcmd.WriteTo(fd, 1)
fd.Close()
if err != nil {
log.Infof("Cannot write dump file: %v", err)
Expand Down

0 comments on commit fdeb5b7

Please sign in to comment.