diff --git a/drizzle-orm/src/aws-data-api/pg/session.ts b/drizzle-orm/src/aws-data-api/pg/session.ts index 53519e6b9..4dee9dfc5 100644 --- a/drizzle-orm/src/aws-data-api/pg/session.ts +++ b/drizzle-orm/src/aws-data-api/pg/session.ts @@ -183,6 +183,8 @@ export class AwsDataApiSession< await tx.setTransaction(config); } try { + await tx.executeRLSConfig(config); + const result = await transaction(tx); await this.client.send(new CommitTransactionCommand({ ...this.rawQuery, transactionId })); return result; diff --git a/drizzle-orm/src/neon-serverless/session.ts b/drizzle-orm/src/neon-serverless/session.ts index 891835259..b53583f5d 100644 --- a/drizzle-orm/src/neon-serverless/session.ts +++ b/drizzle-orm/src/neon-serverless/session.ts @@ -136,6 +136,8 @@ export class NeonSession< const tx = new NeonTransaction(this.dialect, session, this.schema); await tx.execute(sql`begin ${tx.getTransactionConfigSQL(config)}`); try { + await tx.executeRLSConfig(config); + const result = await transaction(tx); await tx.execute(sql`commit`); return result; diff --git a/drizzle-orm/src/node-postgres/session.ts b/drizzle-orm/src/node-postgres/session.ts index 01ad7cac8..2b372aa5d 100644 --- a/drizzle-orm/src/node-postgres/session.ts +++ b/drizzle-orm/src/node-postgres/session.ts @@ -135,6 +135,8 @@ export class NodePgSession< const tx = new NodePgTransaction(this.dialect, session, this.schema); await tx.execute(sql`begin${config ? sql` ${tx.getTransactionConfigSQL(config)}` : undefined}`); try { + await tx.executeRLSConfig(config); + const result = await transaction(tx); await tx.execute(sql`commit`); return result; diff --git a/drizzle-orm/src/pg-core/policy.ts b/drizzle-orm/src/pg-core/policy.ts new file mode 100644 index 000000000..d6ce75389 --- /dev/null +++ b/drizzle-orm/src/pg-core/policy.ts @@ -0,0 +1,53 @@ +import { entityKind } from '~/entity.ts'; +import type { AnyPgTable } from '~/pg-core/table.ts'; +import { Policy } from '~/policy.ts'; +import type { SQL } from '~/sql/sql.ts'; +import type { AnyPgRole } from './role.ts'; + +export type PgPolicyFor = 'select' | 'insert' | 'update' | 'delete' | 'all'; + +export type PgPolicyTo = 'public' | 'current_role' | 'current_user' | 'session_user'; + +export type PgPolicyAs = 'permissive' | 'restrictive'; + +export type PgPolicyConfig = { + as?: PgPolicyAs; + for?: PgPolicyFor; + to?: (AnyPgRole | PgPolicyTo)[]; + using?: SQL; + withCheck?: SQL; +}; + +export const PolicyTable = Symbol.for('drizzle:PolicyTable'); + +export class PgPolicy< + TName extends string, + TTable extends AnyPgTable, + TConfig extends PgPolicyConfig | undefined, +> extends Policy { + static readonly [entityKind]: string = 'PgPolicy'; + + [PolicyTable]: TTable; + config: TConfig; + + declare readonly _: { + readonly brand: 'PgPolicy'; + readonly name: TName; + readonly table: TTable; + readonly config: TConfig; + }; + + constructor(name: TName, table: TTable, config: TConfig) { + super(name); + this[PolicyTable] = table; + this.config = config; + } +} + +export function pgPolicy( + name: TName, + table: TTable, + config?: TConfig, +): PgPolicy { + return new PgPolicy(name, table, config as TConfig); +} diff --git a/drizzle-orm/src/pg-core/role.ts b/drizzle-orm/src/pg-core/role.ts new file mode 100644 index 000000000..671c771f3 --- /dev/null +++ b/drizzle-orm/src/pg-core/role.ts @@ -0,0 +1,46 @@ +import { entityKind } from '~/entity.ts'; +import { Role } from '~/role.ts'; + +// Since the create role clause only allow one of these, I guess drizzle-kit will have to generate +// alter role clauses if the user defines more than one of these +export type PgRoleConfig = { + superuser?: boolean; + createDb?: boolean; + createRole?: boolean; + inherit?: boolean; + login?: boolean; + replication?: boolean; + bypassRLS?: boolean; + connectionLimit?: number; + password?: string; + validUntil?: Date; + inRole?: string; + role?: string[]; // Should this be a PgRole[]? + admin?: string[]; // Should this be a PgRole[]? +}; + +export type AnyPgRole = PgRole; + +export class PgRole extends Role { + static readonly [entityKind]: string = 'PgRole'; + + declare readonly _: { + readonly brand: 'PgRole'; + readonly name: TName; + readonly config: TConfig; + }; + + config: TConfig; + + constructor(name: TName, config: TConfig) { + super(name); + this.config = config; + } +} + +export function pgRole( + name: TName, + config?: TConfig, +): PgRole { + return new PgRole(name, config ?? {} as TConfig); +} diff --git a/drizzle-orm/src/pg-core/session.ts b/drizzle-orm/src/pg-core/session.ts index 2b65567c6..309b7ee27 100644 --- a/drizzle-orm/src/pg-core/session.ts +++ b/drizzle-orm/src/pg-core/session.ts @@ -6,6 +6,7 @@ import { tracer } from '~/tracing.ts'; import { PgDatabase } from './db.ts'; import type { PgDialect } from './dialect.ts'; import type { SelectedFieldsOrdered } from './query-builders/select.types.ts'; +import type { AnyPgRole } from './role.ts'; export interface PreparedQueryConfig { execute: unknown; @@ -29,6 +30,14 @@ export interface PgTransactionConfig { isolationLevel?: 'read uncommitted' | 'read committed' | 'repeatable read' | 'serializable'; accessMode?: 'read only' | 'read write'; deferrable?: boolean; + rlsConfig?: { + set?: { + name: string; + value: string; + isLocal?: boolean; + }[]; + role?: AnyPgRole; + }; } export abstract class PgSession< @@ -114,8 +123,25 @@ export abstract class PgTransaction< return sql.raw(chunks.join(' ')); } - setTransaction(config: PgTransactionConfig): Promise { - return this.session.execute(sql`set transaction ${this.getTransactionConfigSQL(config)}`); + setTransaction(config: PgTransactionConfig): Promise | void { + if (config.accessMode || config.deferrable || config.isolationLevel) { + return this.session.execute(sql`set transaction ${this.getTransactionConfigSQL(config)}`); + } + } + + async executeRLSConfig(config: PgTransactionConfig | undefined): Promise { + const rlsConfig = config?.rlsConfig; + if (rlsConfig) { + if (rlsConfig.set) { + for (const { name, value, isLocal } of rlsConfig.set) { + await this.session.execute(sql`select set_config(${name}, ${value}, ${isLocal === false ? false : true})`); + } + } + + if (rlsConfig.role) { + await this.session.execute(sql`set local role ${rlsConfig.role}`); + } + } } abstract override transaction( diff --git a/drizzle-orm/src/pg-core/table.ts b/drizzle-orm/src/pg-core/table.ts index 0bf686cca..cb42d3f6f 100644 --- a/drizzle-orm/src/pg-core/table.ts +++ b/drizzle-orm/src/pg-core/table.ts @@ -22,6 +22,9 @@ export type TableConfig = TableConfigBase; /** @internal */ export const InlineForeignKeys = Symbol.for('drizzle:PgInlineForeignKeys'); +/** @internal */ +export const RLSEnabled = Symbol.for('drizzle:RLSEnabled'); + export class PgTable extends Table { static readonly [entityKind]: string = 'PgTable'; @@ -33,9 +36,16 @@ export class PgTable extends Table { /**@internal */ [InlineForeignKeys]: ForeignKey[] = []; + [RLSEnabled]: boolean = false; + /** @internal */ override [Table.Symbol.ExtraConfigBuilder]: ((self: Record) => PgTableExtraConfig) | undefined = undefined; + + enableRLS(): this { + this[RLSEnabled] = true; + return this; + } } export type AnyPgTable = {}> = PgTable>; diff --git a/drizzle-orm/src/policy.ts b/drizzle-orm/src/policy.ts new file mode 100644 index 000000000..cb402115d --- /dev/null +++ b/drizzle-orm/src/policy.ts @@ -0,0 +1,24 @@ +import { entityKind } from '~/entity.ts'; +import { SQL, type SQLWrapper } from '~/sql/sql.ts'; + +export const PolicyName = Symbol.for('drizzle:PolicyName'); + +export class Policy< + TName extends string, +> implements SQLWrapper { + static readonly [entityKind]: string = 'Policy'; + + [PolicyName]: TName; + + declare readonly _: { + readonly name: TName; + }; + + constructor(readonly name: TName) { + this[PolicyName] = name; + } + + getSQL(): SQL { + return new SQL([this]); + } +} diff --git a/drizzle-orm/src/postgres-js/session.ts b/drizzle-orm/src/postgres-js/session.ts index f861b260f..ac40939ba 100644 --- a/drizzle-orm/src/postgres-js/session.ts +++ b/drizzle-orm/src/postgres-js/session.ts @@ -139,7 +139,10 @@ export class PostgresJsSession< const tx = new PostgresJsTransaction(this.dialect, session, this.schema); if (config) { await tx.setTransaction(config); + + await tx.executeRLSConfig(config); } + return transaction(tx); }) as Promise; } diff --git a/drizzle-orm/src/role.ts b/drizzle-orm/src/role.ts new file mode 100644 index 000000000..84d53756b --- /dev/null +++ b/drizzle-orm/src/role.ts @@ -0,0 +1,24 @@ +import { entityKind } from '~/entity.ts'; +import { SQL, type SQLWrapper } from '~/sql/sql.ts'; + +export type AnyRole = Role; + +export const RoleName = Symbol.for('drizzle:RoleName'); + +export class Role implements SQLWrapper { + static readonly [entityKind]: string = 'Role'; + + declare readonly _: { + readonly name: TName; + }; + + [RoleName]: TName; + + constructor(readonly name: TName) { + this[RoleName] = name; + } + + getSQL(): SQL { + return new SQL([this]); + } +} diff --git a/drizzle-orm/src/sql/sql.ts b/drizzle-orm/src/sql/sql.ts index cad140d38..0827f4df2 100644 --- a/drizzle-orm/src/sql/sql.ts +++ b/drizzle-orm/src/sql/sql.ts @@ -1,12 +1,14 @@ import { entityKind, is } from '~/entity.ts'; +import type { SelectedFields } from '~/operations.ts'; +import { Policy } from '~/policy.ts'; import { Relation } from '~/relations.ts'; +import { Role } from '~/role.ts'; import { Subquery, SubqueryConfig } from '~/subquery.ts'; import { tracer } from '~/tracing.ts'; import { ViewBaseConfig } from '~/view-common.ts'; import type { AnyColumn } from '../column.ts'; import { Column } from '../column.ts'; import { Table } from '../table.ts'; -import type { SelectedFields } from '~/operations.ts'; /** * This class is used to indicate a primitive param value that is used in `sql` tag. @@ -149,6 +151,10 @@ export class SQL implements SQLWrapper { return { sql: escapeName(chunk.value), params: [] }; } + if (is(chunk, Role) || is(chunk, Policy)) { + return { sql: escapeName(chunk.name), params: [] }; + } + if (chunk === undefined) { return { sql: '', params: [] }; } diff --git a/drizzle-orm/src/vercel-postgres/session.ts b/drizzle-orm/src/vercel-postgres/session.ts index f85748b46..02dc6465f 100644 --- a/drizzle-orm/src/vercel-postgres/session.ts +++ b/drizzle-orm/src/vercel-postgres/session.ts @@ -144,6 +144,8 @@ export class VercelPgSession< const tx = new VercelPgTransaction(this.dialect, session, this.schema); await tx.execute(sql`begin${config ? sql` ${tx.getTransactionConfigSQL(config)}` : undefined}`); try { + await tx.executeRLSConfig(config); + const result = await transaction(tx); await tx.execute(sql`commit`); return result; diff --git a/integration-tests/tests/awsdatapi.test.ts b/integration-tests/tests/awsdatapi.test.ts index 1f390eb70..b3cc4b52c 100644 --- a/integration-tests/tests/awsdatapi.test.ts +++ b/integration-tests/tests/awsdatapi.test.ts @@ -10,6 +10,7 @@ import type { AwsDataApiPgDatabase } from 'drizzle-orm/aws-data-api/pg'; import { drizzle } from 'drizzle-orm/aws-data-api/pg'; import { migrate } from 'drizzle-orm/aws-data-api/pg/migrator'; import { alias, boolean, integer, jsonb, pgTable, pgTableCreator, serial, text, timestamp } from 'drizzle-orm/pg-core'; +import { pgRole } from 'drizzle-orm/pg-core/role'; dotenv.config(); @@ -835,6 +836,146 @@ test.serial('nested transaction rollback', async (t) => { await db.execute(sql`drop table ${users}`); }); +test.serial('transaction with RLS', async (t) => { + const { db } = t.context; + + const users = pgTable('users_transactions', { + id: serial('id').primaryKey(), + name: text('name').notNull(), + lastName: text('lastname'), + }); + + const adminRole = pgRole('admin_all'); + + await db.execute(sql`drop policy if exists administrator_all on ${users}`); + await db.execute(sql`drop policy if exists all_view on ${users}`); + await db.execute(sql`drop role if exists admin_all`); + await db.execute(sql`drop table if exists ${users}`); + + await db.execute( + sql`create table ${users} (id serial not null primary key, name text not null, lastname text)`, + ); + await db.execute(sql`CREATE ROLE ${adminRole}`); + + await db.insert(users).values([{ name: 'admin_all', lastName: 'lastName' }, { + name: 'user2', + lastName: 'lastName2', + }]); + + await db.execute(sql`ALTER TABLE ${users} ENABLE ROW LEVEL SECURITY`); + await db.execute( + sql`CREATE POLICY all_view ON ${users} for select USING (true)`, + ); + + await db.execute(sql`GRANT USAGE ON SCHEMA public TO admin_all`); + await db.execute(sql`GRANT all ON ${users} TO admin_all`); + + await db.transaction(async (tx) => { + await tx.update(users).set({ lastName: 'another' }).where(eq(users.id, 1)); + }, { + rlsConfig: { + role: adminRole, + }, + }); + + const badResult = await db.select().from(users).where(eq(users.id, 1)); + + t.deepEqual(badResult, [{ id: 1, name: 'admin_all', lastName: 'lastName' }]); + + await db.execute( + sql`CREATE POLICY administrator_all ON ${users} TO admin_all USING (true) WITH CHECK (current_role = name)`, + ); + + await db.transaction(async (tx) => { + await tx.update(users).set({ lastName: 'another' }).where(eq(users.id, 1)); + }, { + rlsConfig: { + role: adminRole, + }, + }); + + const result = await db.select().from(users).where(eq(users.id, 1)); + + t.deepEqual(result, [{ id: 1, name: 'admin_all', lastName: 'another' }]); + + await db.execute(sql`drop owned by admin_all`); + await db.execute(sql`drop policy if exists administrator_all on ${users}`); + await db.execute(sql`drop policy if exists all_view on ${users}`); + await db.execute(sql`drop role if exists admin_all`); + await db.execute(sql`drop table if exists ${users}`); +}); + +test.serial('transaction with RLS with attempt to update another user', async (t) => { + const { db } = t.context; + + const users = pgTable('users_transactions', { + id: serial('id').primaryKey(), + name: text('name').notNull(), + lastName: text('lastname'), + }); + + const adminRole = pgRole('admin_all'); + + await db.execute(sql`drop policy if exists administrator_all on ${users}`); + await db.execute(sql`drop policy if exists all_view on ${users}`); + await db.execute(sql`drop role if exists admin_all`); + await db.execute(sql`drop table if exists ${users}`); + + await db.execute( + sql`create table ${users} (id serial not null primary key, name text not null, lastname text)`, + ); + await db.execute(sql`CREATE ROLE ${adminRole}`); + + await db.insert(users).values([{ name: 'admin_all', lastName: 'lastName' }, { + name: 'user2', + lastName: 'lastName2', + }]); + + await db.execute(sql`ALTER TABLE ${users} ENABLE ROW LEVEL SECURITY`); + await db.execute( + sql`CREATE POLICY all_view ON ${users} for select USING (true)`, + ); + + await db.execute(sql`GRANT USAGE ON SCHEMA public TO admin_all`); + await db.execute(sql`GRANT all ON ${users} TO admin_all`); + + await db.execute( + sql`CREATE POLICY administrator_all ON ${users} TO admin_all USING (true) WITH CHECK (current_role = name)`, + ); + + await t.throwsAsync(async () => + await db.transaction(async (tx) => { + await tx.update(users).set({ lastName: 'another' }).where(eq(users.id, 2)); + }, { + rlsConfig: { + role: adminRole, + }, + }) + ); + + const badResult2 = await db.select().from(users).where(eq(users.id, 2)); + + t.deepEqual(badResult2, [{ id: 2, name: 'user2', lastName: 'lastName2' }]); + + await db.transaction(async (tx) => { + await tx.update(users).set({ lastName: 'another' }).where(eq(users.id, 1)); + }, { + rlsConfig: { + role: adminRole, + }, + }); + + const result = await db.select().from(users).where(eq(users.id, 1)); + + t.deepEqual(result, [{ id: 1, name: 'admin_all', lastName: 'another' }]); + + await db.execute(sql`drop owned by admin_all`); + await db.execute(sql`drop policy if exists administrator_all on ${users}`); + await db.execute(sql`drop policy if exists all_view on ${users}`); + await db.execute(sql`drop role if exists admin_all`); + await db.execute(sql`drop table if exists ${users}`); +}); + test.serial('select from raw sql', async (t) => { const { db } = t.context; diff --git a/integration-tests/tests/neon-http.test.ts b/integration-tests/tests/neon-http.test.ts index 5112fd3b1..6ffa13746 100644 --- a/integration-tests/tests/neon-http.test.ts +++ b/integration-tests/tests/neon-http.test.ts @@ -48,6 +48,7 @@ import { uuid as pgUuid, varchar, } from 'drizzle-orm/pg-core'; +import { pgRole } from 'drizzle-orm/pg-core/role'; import getPort from 'get-port'; import pg from 'pg'; import { v4 as uuid } from 'uuid'; @@ -2143,6 +2144,146 @@ test.serial.skip('nested transaction rollback', async (t) => { await db.execute(sql`drop table ${users}`); }); +test.serial('transaction with RLS', async (t) => { + const { db } = t.context; + + const users = pgTable('users_transactions', { + id: serial('id').primaryKey(), + name: text('name').notNull(), + lastName: text('lastname'), + }); + + const adminRole = pgRole('admin_all'); + + await db.execute(sql`drop policy if exists administrator_all on ${users}`); + await db.execute(sql`drop policy if exists all_view on ${users}`); + await db.execute(sql`drop role if exists admin_all`); + await db.execute(sql`drop table if exists ${users}`); + + await db.execute( + sql`create table ${users} (id serial not null primary key, name text not null, lastname text)`, + ); + await db.execute(sql`CREATE ROLE ${adminRole}`); + + await db.insert(users).values([{ name: 'admin_all', lastName: 'lastName' }, { + name: 'user2', + lastName: 'lastName2', + }]); + + await db.execute(sql`ALTER TABLE ${users} ENABLE ROW LEVEL SECURITY`); + await db.execute( + sql`CREATE POLICY all_view ON ${users} for select USING (true)`, + ); + + await db.execute(sql`GRANT USAGE ON SCHEMA public TO admin_all`); + await db.execute(sql`GRANT all ON ${users} TO admin_all`); + + await db.transaction(async (tx) => { + await tx.update(users).set({ lastName: 'another' }).where(eq(users.id, 1)); + }, { + rlsConfig: { + role: adminRole, + }, + }); + + const badResult = await db.select().from(users).where(eq(users.id, 1)); + + t.deepEqual(badResult, [{ id: 1, name: 'admin_all', lastName: 'lastName' }]); + + await db.execute( + sql`CREATE POLICY administrator_all ON ${users} TO admin_all USING (true) WITH CHECK (current_role = name)`, + ); + + await db.transaction(async (tx) => { + await tx.update(users).set({ lastName: 'another' }).where(eq(users.id, 1)); + }, { + rlsConfig: { + role: adminRole, + }, + }); + + const result = await db.select().from(users).where(eq(users.id, 1)); + + t.deepEqual(result, [{ id: 1, name: 'admin_all', lastName: 'another' }]); + + await db.execute(sql`drop owned by admin_all`); + await db.execute(sql`drop policy if exists administrator_all on ${users}`); + await db.execute(sql`drop policy if exists all_view on ${users}`); + await db.execute(sql`drop role if exists admin_all`); + await db.execute(sql`drop table if exists ${users}`); +}); + +test.serial('transaction with RLS with attempt to update another user', async (t) => { + const { db } = t.context; + + const users = pgTable('users_transactions', { + id: serial('id').primaryKey(), + name: text('name').notNull(), + lastName: text('lastname'), + }); + + const adminRole = pgRole('admin_all'); + + await db.execute(sql`drop policy if exists administrator_all on ${users}`); + await db.execute(sql`drop policy if exists all_view on ${users}`); + await db.execute(sql`drop role if exists admin_all`); + await db.execute(sql`drop table if exists ${users}`); + + await db.execute( + sql`create table ${users} (id serial not null primary key, name text not null, lastname text)`, + ); + await db.execute(sql`CREATE ROLE ${adminRole}`); + + await db.insert(users).values([{ name: 'admin_all', lastName: 'lastName' }, { + name: 'user2', + lastName: 'lastName2', + }]); + + await db.execute(sql`ALTER TABLE ${users} ENABLE ROW LEVEL SECURITY`); + await db.execute( + sql`CREATE POLICY all_view ON ${users} for select USING (true)`, + ); + + await db.execute(sql`GRANT USAGE ON SCHEMA public TO admin_all`); + await db.execute(sql`GRANT all ON ${users} TO admin_all`); + + await db.execute( + sql`CREATE POLICY administrator_all ON ${users} TO admin_all USING (true) WITH CHECK (current_role = name)`, + ); + + await t.throwsAsync(async () => + await db.transaction(async (tx) => { + await tx.update(users).set({ lastName: 'another' }).where(eq(users.id, 2)); + }, { + rlsConfig: { + role: adminRole, + }, + }) + ); + + const badResult2 = await db.select().from(users).where(eq(users.id, 2)); + + t.deepEqual(badResult2, [{ id: 2, name: 'user2', lastName: 'lastName2' }]); + + await db.transaction(async (tx) => { + await tx.update(users).set({ lastName: 'another' }).where(eq(users.id, 1)); + }, { + rlsConfig: { + role: adminRole, + }, + }); + + const result = await db.select().from(users).where(eq(users.id, 1)); + + t.deepEqual(result, [{ id: 1, name: 'admin_all', lastName: 'another' }]); + + await db.execute(sql`drop owned by admin_all`); + await db.execute(sql`drop policy if exists administrator_all on ${users}`); + await db.execute(sql`drop policy if exists all_view on ${users}`); + await db.execute(sql`drop role if exists admin_all`); + await db.execute(sql`drop table if exists ${users}`); +}); + test.serial('join subquery with join', async (t) => { const { db } = t.context; diff --git a/integration-tests/tests/pg.test.ts b/integration-tests/tests/pg.test.ts index 3c1b1c9d7..dcdc14e46 100644 --- a/integration-tests/tests/pg.test.ts +++ b/integration-tests/tests/pg.test.ts @@ -50,6 +50,7 @@ import { macaddr, macaddr8, type PgColumn, + pgEnum, pgMaterializedView, pgTable, pgTableCreator, @@ -64,12 +65,12 @@ import { uniqueKeyName, uuid as pgUuid, varchar, - pgEnum, } from 'drizzle-orm/pg-core'; +import { pgRole } from 'drizzle-orm/pg-core/role'; import getPort from 'get-port'; import pg from 'pg'; import { v4 as uuid } from 'uuid'; -import { type Equal, Expect } from './utils.ts'; +import { type Equal, Expect } from './utils'; const { Client } = pg; @@ -2426,6 +2427,146 @@ test.serial('nested transaction rollback', async (t) => { await db.execute(sql`drop table ${users}`); }); +test.serial('transaction with RLS', async (t) => { + const { db } = t.context; + + const users = pgTable('users_transactions', { + id: serial('id').primaryKey(), + name: text('name').notNull(), + lastName: text('lastname'), + }); + + const adminRole = pgRole('admin_all'); + + await db.execute(sql`drop policy if exists administrator_all on ${users}`); + await db.execute(sql`drop policy if exists all_view on ${users}`); + await db.execute(sql`drop role if exists admin_all`); + await db.execute(sql`drop table if exists ${users}`); + + await db.execute( + sql`create table ${users} (id serial not null primary key, name text not null, lastname text)`, + ); + await db.execute(sql`CREATE ROLE ${adminRole}`); + + await db.insert(users).values([{ name: 'admin_all', lastName: 'lastName' }, { + name: 'user2', + lastName: 'lastName2', + }]); + + await db.execute(sql`ALTER TABLE ${users} ENABLE ROW LEVEL SECURITY`); + await db.execute( + sql`CREATE POLICY all_view ON ${users} for select USING (true)`, + ); + + await db.execute(sql`GRANT USAGE ON SCHEMA public TO admin_all`); + await db.execute(sql`GRANT all ON ${users} TO admin_all`); + + await db.transaction(async (tx) => { + await tx.update(users).set({ lastName: 'another' }).where(eq(users.id, 1)); + }, { + rlsConfig: { + role: adminRole, + }, + }); + + const badResult = await db.select().from(users).where(eq(users.id, 1)); + + t.deepEqual(badResult, [{ id: 1, name: 'admin_all', lastName: 'lastName' }]); + + await db.execute( + sql`CREATE POLICY administrator_all ON ${users} TO admin_all USING (true) WITH CHECK (current_role = name)`, + ); + + await db.transaction(async (tx) => { + await tx.update(users).set({ lastName: 'another' }).where(eq(users.id, 1)); + }, { + rlsConfig: { + role: adminRole, + }, + }); + + const result = await db.select().from(users).where(eq(users.id, 1)); + + t.deepEqual(result, [{ id: 1, name: 'admin_all', lastName: 'another' }]); + + await db.execute(sql`drop owned by admin_all`); + await db.execute(sql`drop policy if exists administrator_all on ${users}`); + await db.execute(sql`drop policy if exists all_view on ${users}`); + await db.execute(sql`drop role if exists admin_all`); + await db.execute(sql`drop table if exists ${users}`); +}); + +test.serial('transaction with RLS with attempt to update another user', async (t) => { + const { db } = t.context; + + const users = pgTable('users_transactions', { + id: serial('id').primaryKey(), + name: text('name').notNull(), + lastName: text('lastname'), + }); + + const adminRole = pgRole('admin_all'); + + await db.execute(sql`drop policy if exists administrator_all on ${users}`); + await db.execute(sql`drop policy if exists all_view on ${users}`); + await db.execute(sql`drop role if exists admin_all`); + await db.execute(sql`drop table if exists ${users}`); + + await db.execute( + sql`create table ${users} (id serial not null primary key, name text not null, lastname text)`, + ); + await db.execute(sql`CREATE ROLE ${adminRole}`); + + await db.insert(users).values([{ name: 'admin_all', lastName: 'lastName' }, { + name: 'user2', + lastName: 'lastName2', + }]); + + await db.execute(sql`ALTER TABLE ${users} ENABLE ROW LEVEL SECURITY`); + await db.execute( + sql`CREATE POLICY all_view ON ${users} for select USING (true)`, + ); + + await db.execute(sql`GRANT USAGE ON SCHEMA public TO admin_all`); + await db.execute(sql`GRANT all ON ${users} TO admin_all`); + + await db.execute( + sql`CREATE POLICY administrator_all ON ${users} TO admin_all USING (true) WITH CHECK (current_role = name)`, + ); + + await t.throwsAsync(async () => + await db.transaction(async (tx) => { + await tx.update(users).set({ lastName: 'another' }).where(eq(users.id, 2)); + }, { + rlsConfig: { + role: adminRole, + }, + }) + ); + + const badResult2 = await db.select().from(users).where(eq(users.id, 2)); + + t.deepEqual(badResult2, [{ id: 2, name: 'user2', lastName: 'lastName2' }]); + + await db.transaction(async (tx) => { + await tx.update(users).set({ lastName: 'another' }).where(eq(users.id, 1)); + }, { + rlsConfig: { + role: adminRole, + }, + }); + + const result = await db.select().from(users).where(eq(users.id, 1)); + + t.deepEqual(result, [{ id: 1, name: 'admin_all', lastName: 'another' }]); + + await db.execute(sql`drop owned by admin_all`); + await db.execute(sql`drop policy if exists administrator_all on ${users}`); + await db.execute(sql`drop policy if exists all_view on ${users}`); + await db.execute(sql`drop role if exists admin_all`); + await db.execute(sql`drop table if exists ${users}`); +}); + test.serial('join subquery with join', async (t) => { const { db } = t.context; diff --git a/integration-tests/tests/postgres.js.test.ts b/integration-tests/tests/postgres.js.test.ts index a1f979f60..80502ec0d 100644 --- a/integration-tests/tests/postgres.js.test.ts +++ b/integration-tests/tests/postgres.js.test.ts @@ -41,13 +41,14 @@ import { uuid as pgUuid, varchar, } from 'drizzle-orm/pg-core'; +import { pgRole } from 'drizzle-orm/pg-core/role'; import type { PostgresJsDatabase } from 'drizzle-orm/postgres-js'; import { drizzle } from 'drizzle-orm/postgres-js'; import { migrate } from 'drizzle-orm/postgres-js/migrator'; import getPort from 'get-port'; import postgres, { type Sql } from 'postgres'; import { v4 as uuid } from 'uuid'; -import { type Equal, Expect } from './utils.ts'; +import { type Equal, Expect } from './utils'; const QUERY_LOGGING = false; @@ -1974,6 +1975,146 @@ test.serial('nested transaction rollback', async (t) => { await db.execute(sql`drop table ${users}`); }); +test.serial('transaction with RLS', async (t) => { + const { db } = t.context; + + const users = pgTable('users_transactions', { + id: serial('id').primaryKey(), + name: text('name').notNull(), + lastName: text('lastname'), + }); + + const adminRole = pgRole('admin_all'); + + await db.execute(sql`drop policy if exists administrator_all on ${users}`); + await db.execute(sql`drop policy if exists all_view on ${users}`); + await db.execute(sql`drop role if exists admin_all`); + await db.execute(sql`drop table if exists ${users}`); + + await db.execute( + sql`create table ${users} (id serial not null primary key, name text not null, lastname text)`, + ); + await db.execute(sql`CREATE ROLE ${adminRole}`); + + await db.insert(users).values([{ name: 'admin_all', lastName: 'lastName' }, { + name: 'user2', + lastName: 'lastName2', + }]); + + await db.execute(sql`ALTER TABLE ${users} ENABLE ROW LEVEL SECURITY`); + await db.execute( + sql`CREATE POLICY all_view ON ${users} for select USING (true)`, + ); + + await db.execute(sql`GRANT USAGE ON SCHEMA public TO admin_all`); + await db.execute(sql`GRANT all ON ${users} TO admin_all`); + + await db.transaction(async (tx) => { + await tx.update(users).set({ lastName: 'another' }).where(eq(users.id, 1)); + }, { + rlsConfig: { + role: adminRole, + }, + }); + + const badResult = await db.select().from(users).where(eq(users.id, 1)); + + t.deepEqual(badResult, [{ id: 1, name: 'admin_all', lastName: 'lastName' }]); + + await db.execute( + sql`CREATE POLICY administrator_all ON ${users} TO admin_all USING (true) WITH CHECK (current_role = name)`, + ); + + await db.transaction(async (tx) => { + await tx.update(users).set({ lastName: 'another' }).where(eq(users.id, 1)); + }, { + rlsConfig: { + role: adminRole, + }, + }); + + const result = await db.select().from(users).where(eq(users.id, 1)); + + t.deepEqual(result, [{ id: 1, name: 'admin_all', lastName: 'another' }]); + + await db.execute(sql`drop owned by admin_all`); + await db.execute(sql`drop policy if exists administrator_all on ${users}`); + await db.execute(sql`drop policy if exists all_view on ${users}`); + await db.execute(sql`drop role if exists admin_all`); + await db.execute(sql`drop table if exists ${users}`); +}); + +test.serial('transaction with RLS with attempt to update another user', async (t) => { + const { db } = t.context; + + const users = pgTable('users_transactions', { + id: serial('id').primaryKey(), + name: text('name').notNull(), + lastName: text('lastname'), + }); + + const adminRole = pgRole('admin_all'); + + await db.execute(sql`drop policy if exists administrator_all on ${users}`); + await db.execute(sql`drop policy if exists all_view on ${users}`); + await db.execute(sql`drop role if exists admin_all`); + await db.execute(sql`drop table if exists ${users}`); + + await db.execute( + sql`create table ${users} (id serial not null primary key, name text not null, lastname text)`, + ); + await db.execute(sql`CREATE ROLE ${adminRole}`); + + await db.insert(users).values([{ name: 'admin_all', lastName: 'lastName' }, { + name: 'user2', + lastName: 'lastName2', + }]); + + await db.execute(sql`ALTER TABLE ${users} ENABLE ROW LEVEL SECURITY`); + await db.execute( + sql`CREATE POLICY all_view ON ${users} for select USING (true)`, + ); + + await db.execute(sql`GRANT USAGE ON SCHEMA public TO admin_all`); + await db.execute(sql`GRANT all ON ${users} TO admin_all`); + + await db.execute( + sql`CREATE POLICY administrator_all ON ${users} TO admin_all USING (true) WITH CHECK (current_role = name)`, + ); + + await t.throwsAsync(async () => + await db.transaction(async (tx) => { + await tx.update(users).set({ lastName: 'another' }).where(eq(users.id, 2)); + }, { + rlsConfig: { + role: adminRole, + }, + }) + ); + + const badResult2 = await db.select().from(users).where(eq(users.id, 2)); + + t.deepEqual(badResult2, [{ id: 2, name: 'user2', lastName: 'lastName2' }]); + + await db.transaction(async (tx) => { + await tx.update(users).set({ lastName: 'another' }).where(eq(users.id, 1)); + }, { + rlsConfig: { + role: adminRole, + }, + }); + + const result = await db.select().from(users).where(eq(users.id, 1)); + + t.deepEqual(result, [{ id: 1, name: 'admin_all', lastName: 'another' }]); + + await db.execute(sql`drop owned by admin_all`); + await db.execute(sql`drop policy if exists administrator_all on ${users}`); + await db.execute(sql`drop policy if exists all_view on ${users}`); + await db.execute(sql`drop role if exists admin_all`); + await db.execute(sql`drop table if exists ${users}`); +}); + test.serial('join subquery with join', async (t) => { const { db } = t.context; diff --git a/integration-tests/tests/vercel-pg.test.ts b/integration-tests/tests/vercel-pg.test.ts index c51f2bd62..d3f607c3f 100644 --- a/integration-tests/tests/vercel-pg.test.ts +++ b/integration-tests/tests/vercel-pg.test.ts @@ -46,6 +46,7 @@ import { uuid as pgUuid, varchar, } from 'drizzle-orm/pg-core'; +import { pgRole } from 'drizzle-orm/pg-core/role'; import { drizzle, type VercelPgDatabase } from 'drizzle-orm/vercel-postgres'; import { migrate } from 'drizzle-orm/vercel-postgres/migrator'; import getPort from 'get-port'; @@ -2121,6 +2122,146 @@ test.serial('nested transaction rollback', async (t) => { await db.execute(sql`drop table ${users}`); }); +test.serial('transaction with RLS', async (t) => { + const { db } = t.context; + + const users = pgTable('users_transactions', { + id: serial('id').primaryKey(), + name: text('name').notNull(), + lastName: text('lastname'), + }); + + const adminRole = pgRole('admin_all'); + + await db.execute(sql`drop policy if exists administrator_all on ${users}`); + await db.execute(sql`drop policy if exists all_view on ${users}`); + await db.execute(sql`drop role if exists admin_all`); + await db.execute(sql`drop table if exists ${users}`); + + await db.execute( + sql`create table ${users} (id serial not null primary key, name text not null, lastname text)`, + ); + await db.execute(sql`CREATE ROLE ${adminRole}`); + + await db.insert(users).values([{ name: 'admin_all', lastName: 'lastName' }, { + name: 'user2', + lastName: 'lastName2', + }]); + + await db.execute(sql`ALTER TABLE ${users} ENABLE ROW LEVEL SECURITY`); + await db.execute( + sql`CREATE POLICY all_view ON ${users} for select USING (true)`, + ); + + await db.execute(sql`GRANT USAGE ON SCHEMA public TO admin_all`); + await db.execute(sql`GRANT all ON ${users} TO admin_all`); + + await db.transaction(async (tx) => { + await tx.update(users).set({ lastName: 'another' }).where(eq(users.id, 1)); + }, { + rlsConfig: { + role: adminRole, + }, + }); + + const badResult = await db.select().from(users).where(eq(users.id, 1)); + + t.deepEqual(badResult, [{ id: 1, name: 'admin_all', lastName: 'lastName' }]); + + await db.execute( + sql`CREATE POLICY administrator_all ON ${users} TO admin_all USING (true) WITH CHECK (current_role = name)`, + ); + + await db.transaction(async (tx) => { + await tx.update(users).set({ lastName: 'another' }).where(eq(users.id, 1)); + }, { + rlsConfig: { + role: adminRole, + }, + }); + + const result = await db.select().from(users).where(eq(users.id, 1)); + + t.deepEqual(result, [{ id: 1, name: 'admin_all', lastName: 'another' }]); + + await db.execute(sql`drop owned by admin_all`); + await db.execute(sql`drop policy if exists administrator_all on ${users}`); + await db.execute(sql`drop policy if exists all_view on ${users}`); + await db.execute(sql`drop role if exists admin_all`); + await db.execute(sql`drop table if exists ${users}`); +}); + +test.serial('transaction with RLS with attempt to update another user', async (t) => { + const { db } = t.context; + + const users = pgTable('users_transactions', { + id: serial('id').primaryKey(), + name: text('name').notNull(), + lastName: text('lastname'), + }); + + const adminRole = pgRole('admin_all'); + + await db.execute(sql`drop policy if exists administrator_all on ${users}`); + await db.execute(sql`drop policy if exists all_view on ${users}`); + await db.execute(sql`drop role if exists admin_all`); + await db.execute(sql`drop table if exists ${users}`); + + await db.execute( + sql`create table ${users} (id serial not null primary key, name text not null, lastname text)`, + ); + await db.execute(sql`CREATE ROLE ${adminRole}`); + + await db.insert(users).values([{ name: 'admin_all', lastName: 'lastName' }, { + name: 'user2', + lastName: 'lastName2', + }]); + + await db.execute(sql`ALTER TABLE ${users} ENABLE ROW LEVEL SECURITY`); + await db.execute( + sql`CREATE POLICY all_view ON ${users} for select USING (true)`, + ); + + await db.execute(sql`GRANT USAGE ON SCHEMA public TO admin_all`); + await db.execute(sql`GRANT all ON ${users} TO admin_all`); + + await db.execute( + sql`CREATE POLICY administrator_all ON ${users} TO admin_all USING (true) WITH CHECK (current_role = name)`, + ); + + await t.throwsAsync(async () => + await db.transaction(async (tx) => { + await tx.update(users).set({ lastName: 'another' }).where(eq(users.id, 2)); + }, { + rlsConfig: { + role: adminRole, + }, + }) + ); + + const badResult2 = await db.select().from(users).where(eq(users.id, 2)); + + t.deepEqual(badResult2, [{ id: 2, name: 'user2', lastName: 'lastName2' }]); + + await db.transaction(async (tx) => { + await tx.update(users).set({ lastName: 'another' }).where(eq(users.id, 1)); + }, { + rlsConfig: { + role: adminRole, + }, + }); + + const result = await db.select().from(users).where(eq(users.id, 1)); + + t.deepEqual(result, [{ id: 1, name: 'admin_all', lastName: 'another' }]); + + await db.execute(sql`drop owned by admin_all`); + await db.execute(sql`drop policy if exists administrator_all on ${users}`); + await db.execute(sql`drop policy if exists all_view on ${users}`); + await db.execute(sql`drop role if exists admin_all`); + await db.execute(sql`drop table if exists ${users}`); +}); + test.serial('join subquery with join', async (t) => { const { db } = t.context;