Skip to content

Commit

Permalink
Implement SQL permissions
Browse files Browse the repository at this point in the history
Signed-off-by: Stefano Scafiti <[email protected]>
  • Loading branch information
ostafen committed Jul 24, 2024
1 parent 3eec9d4 commit c189283
Show file tree
Hide file tree
Showing 23 changed files with 4,579 additions and 2,733 deletions.
8 changes: 8 additions & 0 deletions embedded/sql/dummy_data_source_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ type dummyDataSource struct {
AliasFunc func() string
}

func (d *dummyDataSource) readOnly() bool {
return true
}

func (d *dummyDataSource) requiredPrivileges() []SQLPrivilege {
return []SQLPrivilege{SQLPrivilegeSelect}
}

func (d *dummyDataSource) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
return tx, nil
}
Expand Down
58 changes: 57 additions & 1 deletion embedded/sql/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ var (
ErrColumnMismatchInUnionStmt = errors.New("column mismatch in union statement")
ErrCannotIndexJson = errors.New("cannot index column of type JSON")
ErrInvalidTxMetadata = errors.New("invalid transaction metadata")
ErrAccessDenied = errors.New("access denied")
)

var MaxKeyLen = 512
Expand Down Expand Up @@ -123,16 +124,20 @@ type MultiDBHandler interface {
ListDatabases(ctx context.Context) ([]string, error)
CreateDatabase(ctx context.Context, db string, ifNotExists bool) error
UseDatabase(ctx context.Context, db string) error
GetLoggedUser(ctx context.Context) (User, error)
ListUsers(ctx context.Context) ([]User, error)
CreateUser(ctx context.Context, username, password string, permission Permission) error
AlterUser(ctx context.Context, username, password string, permission Permission) error
GrantSQLPrivileges(ctx context.Context, database, username string, privileges []SQLPrivilege) error
RevokeSQLPrivileges(ctx context.Context, database, username string, privileges []SQLPrivilege) error
DropUser(ctx context.Context, username string) error
ExecPreparedStmts(ctx context.Context, opts *TxOptions, stmts []SQLStmt, params map[string]interface{}) (ntx *SQLTx, committedTxs []*SQLTx, err error)
}

type User interface {
Username() string
Permission() uint32
Permission() Permission
SQLPrivileges() []SQLPrivilege
}

func NewEngine(st *store.ImmuStore, opts *Options) (*Engine, error) {
Expand Down Expand Up @@ -475,6 +480,13 @@ func (e *Engine) execPreparedStmts(ctx context.Context, tx *SQLTx, stmts []SQLSt
}
}

if e.multidbHandler != nil {
if err := e.checkUserPermissions(ctx, stmt); err != nil {
currTx.Cancel()
return nil, committedTxs, stmts[execStmts:], err
}
}

ntx, err := stmt.execAt(ctx, currTx, nparams)
if err != nil {
currTx.Cancel()
Expand Down Expand Up @@ -517,6 +529,44 @@ func (e *Engine) execPreparedStmts(ctx context.Context, tx *SQLTx, stmts []SQLSt
return currTx, committedTxs, stmts[execStmts:], nil
}

func (e *Engine) checkUserPermissions(ctx context.Context, stmt SQLStmt) error {
user, err := e.multidbHandler.GetLoggedUser(ctx)
if err != nil {
return err
}

if user.Permission() == PermissionAdmin {
return nil
}

if !stmt.readOnly() && user.Permission() == PermissionReadOnly {
return ErrAccessDenied
}

requiredPrivileges := stmt.requiredPrivileges()
if !hasAllPrivileges(user.SQLPrivileges(), requiredPrivileges) {
return fmt.Errorf("%w: statement requires %v privileges", ErrAccessDenied, requiredPrivileges)
}
return nil
}

func hasAllPrivileges(userPrivileges, privileges []SQLPrivilege) bool {
for _, p := range privileges {
has := false
for _, up := range userPrivileges {
if up == p {
has = true
break
}
}

if !has {
return false
}
}
return true
}

func (e *Engine) queryAll(ctx context.Context, tx *SQLTx, sql string, params map[string]interface{}) ([]*Row, error) {
reader, err := e.Query(ctx, tx, sql, params)
if err != nil {
Expand Down Expand Up @@ -568,6 +618,12 @@ func (e *Engine) QueryPreparedStmt(ctx context.Context, tx *SQLTx, stmt DataSour
return nil, err
}

if e.multidbHandler != nil {
if err := e.checkUserPermissions(ctx, stmt); err != nil {
return nil, err
}
}

_, err = stmt.execAt(ctx, qtx, nparams)
if err != nil {
return nil, err
Expand Down
144 changes: 132 additions & 12 deletions embedded/sql/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7089,7 +7089,12 @@ func TestMultiDBCatalogQueries(t *testing.T) {
defer closeStore(t, st)

dbs := []string{"db1", "db2"}
handler := &multidbHandlerMock{}
handler := &multidbHandlerMock{
user: &mockUser{
username: "user",
sqlPrivileges: allPrivileges,
},
}

opts := DefaultOptions().
WithPrefix(sqlPrefix).
Expand Down Expand Up @@ -7130,6 +7135,13 @@ func TestMultiDBCatalogQueries(t *testing.T) {
`, nil)
require.ErrorIs(t, err, ErrNonTransactionalStmt)

_, _, err = engine.Exec(context.Background(), nil, `
BEGIN TRANSACTION;
GRANT ALL PRIVILEGES ON DATABASE defaultdb TO USER myuser;
COMMIT;
`, nil)
require.ErrorIs(t, err, ErrNonTransactionalStmt)

_, _, err = engine.Exec(context.Background(), nil, "CREATE DATABASE db1", nil)
require.ErrorIs(t, err, ErrNoSupported)

Expand Down Expand Up @@ -7185,23 +7197,19 @@ func TestMultiDBCatalogQueries(t *testing.T) {
})

t.Run("show users", func(t *testing.T) {
r, err := engine.Query(context.Background(), nil, "SHOW USERS", nil)
rows, err := engine.queryAll(context.Background(), nil, "SHOW USERS", nil)
require.NoError(t, err)
require.Len(t, rows, 1)

defer r.Close()

_, err = r.Read(context.Background())
require.ErrorIs(t, err, ErrNoMoreRows)
require.Equal(t, "user", rows[0].ValuesByPosition[0].RawValue())
})

t.Run("list users", func(t *testing.T) {
r, err := engine.Query(context.Background(), nil, "SELECT * FROM USERS()", nil)
rows, err := engine.queryAll(context.Background(), nil, "SELECT * FROM USERS()", nil)
require.NoError(t, err)
require.Len(t, rows, 1)

defer r.Close()

_, err = r.Read(context.Background())
require.ErrorIs(t, err, ErrNoMoreRows)
require.Equal(t, "user", rows[0].ValuesByPosition[0].RawValue())
})

t.Run("query databases using conditions with table and column aliasing", func(t *testing.T) {
Expand All @@ -7225,8 +7233,27 @@ func TestMultiDBCatalogQueries(t *testing.T) {
})
}

type mockUser struct {
username string
permission Permission
sqlPrivileges []SQLPrivilege
}

func (u *mockUser) Username() string {
return u.username
}

func (u *mockUser) Permission() Permission {
return u.permission
}

func (u *mockUser) SQLPrivileges() []SQLPrivilege {
return u.sqlPrivileges
}

type multidbHandlerMock struct {
dbs []string
user *mockUser
engine *Engine
}

Expand All @@ -7238,12 +7265,27 @@ func (h *multidbHandlerMock) CreateDatabase(ctx context.Context, db string, ifNo
return ErrNoSupported
}

func (h *multidbHandlerMock) GrantSQLPrivileges(ctx context.Context, database, username string, privileges []SQLPrivilege) error {
return ErrNoSupported
}

func (h *multidbHandlerMock) RevokeSQLPrivileges(ctx context.Context, database, username string, privileges []SQLPrivilege) error {
return ErrNoSupported
}

func (h *multidbHandlerMock) UseDatabase(ctx context.Context, db string) error {
return nil
}

func (h *multidbHandlerMock) GetLoggedUser(ctx context.Context) (User, error) {
if h.user == nil {
return nil, fmt.Errorf("no logged user")
}
return h.user, nil
}

func (h *multidbHandlerMock) ListUsers(ctx context.Context) ([]User, error) {
return nil, nil
return []User{h.user}, nil
}

func (h *multidbHandlerMock) CreateUser(ctx context.Context, username, password string, permission Permission) error {
Expand Down Expand Up @@ -8728,3 +8770,81 @@ func TestQueryTxMetadata(t *testing.T) {
)
require.ErrorIs(t, err, ErrInvalidTxMetadata)
}

func TestGrantSQLPrivileges(t *testing.T) {
st, err := store.Open(t.TempDir(), store.DefaultOptions().WithMultiIndexing(true))
require.NoError(t, err)
defer closeStore(t, st)

dbs := []string{"db1", "db2"}
handler := &multidbHandlerMock{
dbs: dbs,
user: &mockUser{
username: "myuser",
permission: PermissionReadOnly,
sqlPrivileges: []SQLPrivilege{SQLPrivilegeSelect},
},
}

opts := DefaultOptions().
WithPrefix(sqlPrefix).
WithMultiDBHandler(handler)

engine, err := NewEngine(st, opts)
require.NoError(t, err)

handler.dbs = dbs
handler.engine = engine

tx, err := engine.NewTx(context.Background(), DefaultTxOptions())
require.NoError(t, err)

_, _, err = engine.Exec(
context.Background(),
tx,
"CREATE TABLE mytable(id INTEGER, PRIMARY KEY id);",
nil,
)
require.ErrorIs(t, err, ErrAccessDenied)

handler.user.sqlPrivileges =
append(handler.user.sqlPrivileges, SQLPrivilegeCreate)

_, _, err = engine.Exec(
context.Background(),
tx,
"CREATE TABLE mytable(id INTEGER, PRIMARY KEY id);",
nil,
)
require.ErrorIs(t, err, ErrAccessDenied)

handler.user.permission = PermissionReadWrite

_, _, err = engine.Exec(
context.Background(),
tx,
"CREATE TABLE mytable(id INTEGER, PRIMARY KEY id);",
nil,
)
require.NoError(t, err)

checkGrants := func(sql string) {
rows, err := engine.queryAll(context.Background(), nil, sql, nil)
require.NoError(t, err)
require.Len(t, rows, 2)

usr := rows[0].ValuesByPosition[0].RawValue().(string)
privilege := rows[0].ValuesByPosition[1].RawValue().(string)

require.Equal(t, usr, "myuser")
require.Equal(t, privilege, string(SQLPrivilegeSelect))

usr = rows[1].ValuesByPosition[0].RawValue().(string)
privilege = rows[1].ValuesByPosition[1].RawValue().(string)
require.Equal(t, usr, "myuser")
require.Equal(t, privilege, string(SQLPrivilegeCreate))
}

checkGrants("SHOW GRANTS")
checkGrants("SHOW GRANTS FOR myuser")
}
5 changes: 5 additions & 0 deletions embedded/sql/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ var reservedWords = map[string]int{
"READ": READ,
"READWRITE": READWRITE,
"ADMIN": ADMIN,
"GRANT": GRANT,
"REVOKE": REVOKE,
"GRANTS": GRANTS,
"FOR": FOR,
"PRIVILEGES": PRIVILEGES,
"CHECK": CHECK,
"CONSTRAINT": CONSTRAINT,
}
Expand Down
63 changes: 63 additions & 0 deletions embedded/sql/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1729,6 +1729,69 @@ func TestFloatCornerCases(t *testing.T) {
}
}

func TestGrantRevokeStmt(t *testing.T) {
type test struct {
text string
expectedStmt SQLStmt
}

cases := []test{
{
text: "GRANT SELECT, INSERT, UPDATE, DELETE ON DATABASE defaultdb TO USER immudb",
expectedStmt: &AlterPrivilegesStmt{
database: "defaultdb",
user: "immudb",
privileges: []SQLPrivilege{
SQLPrivilegeDelete,
SQLPrivilegeUpdate,
SQLPrivilegeInsert,
SQLPrivilegeSelect,
},
isGrant: true,
},
},
{
text: "REVOKE SELECT, INSERT, UPDATE, DELETE ON DATABASE defaultdb TO USER immudb",
expectedStmt: &AlterPrivilegesStmt{
database: "defaultdb",
user: "immudb",
privileges: []SQLPrivilege{
SQLPrivilegeDelete,
SQLPrivilegeUpdate,
SQLPrivilegeInsert,
SQLPrivilegeSelect,
},
},
},
{
text: "GRANT ALL PRIVILEGES ON DATABASE defaultdb TO USER immudb",
expectedStmt: &AlterPrivilegesStmt{
database: "defaultdb",
user: "immudb",
privileges: allPrivileges,
isGrant: true,
},
},
{
text: "REVOKE ALL PRIVILEGES ON DATABASE defaultdb TO USER immudb",
expectedStmt: &AlterPrivilegesStmt{
database: "defaultdb",
user: "immudb",
privileges: allPrivileges,
},
},
}

for i, tc := range cases {
t.Run(fmt.Sprintf("alter_privileges_%d", i), func(t *testing.T) {
stmts, err := ParseSQLString(tc.text)
require.NoError(t, err)
require.Len(t, stmts, 1)
require.Equal(t, tc.expectedStmt, stmts[0])
})
}
}

func TestExprString(t *testing.T) {
exps := []string{
"(1 + 1) / (2 * 5 - 10)",
Expand Down
Loading

0 comments on commit c189283

Please sign in to comment.