diff --git a/docs/core_docs/docs/how_to/sql_large_db.mdx b/docs/core_docs/docs/how_to/sql_large_db.mdx index c36e5049c095..de7db13ed557 100644 --- a/docs/core_docs/docs/how_to/sql_large_db.mdx +++ b/docs/core_docs/docs/how_to/sql_large_db.mdx @@ -44,6 +44,24 @@ import DbCheck from "@examples/use_cases/sql/db_check.ts"; {DbCheck} +:::info Security Considerations for Large Databases + +When working with large databases, security becomes even more critical: + +```typescript +const secureDb = await SqlDatabase.fromDataSourceParams({ + appDataSource: datasource, + allowedStatements: ["SELECT"], // Read-only access + includesTables: ["safe_table1"], // Limit accessible tables + enableSqlValidation: true, // SQL injection protection + maxQueryLength: 5000, // Prevent long-running queries +}); +``` + +Always use parameterized queries when filtering data: `db.run("SELECT * FROM users WHERE id = ?", [userId])` + +::: + ## Many tables One of the main pieces of information we need to include in our prompt is the schemas of the relevant tables. diff --git a/docs/core_docs/docs/how_to/sql_prompting.mdx b/docs/core_docs/docs/how_to/sql_prompting.mdx index c2bd5a78be22..abe57be1d34c 100644 --- a/docs/core_docs/docs/how_to/sql_prompting.mdx +++ b/docs/core_docs/docs/how_to/sql_prompting.mdx @@ -43,6 +43,27 @@ import DbCheck from "@examples/use_cases/sql/db_check.ts"; {DbCheck} +:::tip Enhanced Security Features + +The `SqlDatabase` class allows to set security parameters, e.g.: + +```typescript +const secureDb = await SqlDatabase.fromDataSourceParams({ + appDataSource: datasource, + allowedStatements: ["SELECT"], // Restrict to read-only queries + enableSqlValidation: true, // SQL injection protection (default) + maxQueryLength: 10000, // Query length limit +}); + +// Use parameterized queries for safer parameter binding +const result = await secureDb.run( + "SELECT * FROM Artist WHERE ArtistId > ? LIMIT ?", + [5, 10] +); +``` + +::: + ## Dialect-specific prompting One of the simplest things we can do is make our prompt specific to the SQL dialect we're using. diff --git a/docs/core_docs/docs/tutorials/sql_qa.ipynb b/docs/core_docs/docs/tutorials/sql_qa.ipynb index e1e81ce7145e..3ba90c2c4c56 100644 --- a/docs/core_docs/docs/tutorials/sql_qa.ipynb +++ b/docs/core_docs/docs/tutorials/sql_qa.ipynb @@ -21,7 +21,15 @@ "\n", "## ⚠️ Security note ⚠️\n", "\n", - "Building Q&A systems of SQL databases requires executing model-generated SQL queries. There are inherent risks in doing this. Make sure that your database connection permissions are always scoped as narrowly as possible for your chain/agent's needs. This will mitigate though not eliminate the risks of building a model-driven system. For more on general security best practices, [see here](/docs/security).\n", + "Building Q&A systems of SQL databases requires executing model-generated SQL queries. There are inherent risks in doing this. Make sure that your database connection permissions are always scoped as narrowly as possible for your chain/agent's needs. \n", + "\n", + "**Enhanced Security Features**: The `SqlDatabase` class now includes built-in security protections:\n", + "- **SQL injection detection**: Automatically blocks dangerous patterns\n", + "- **Statement restrictions**: Configure `allowedStatements` to limit query types (e.g., `[\"SELECT\"]` for read-only)\n", + "- **Parameterized queries**: Use `db.run(query, [param1, param2])` for safer parameter binding\n", + "- **Query validation**: Enable `enableSqlValidation` (default: true) for additional protection\n", + "\n", + "These features mitigate though do not eliminate the risks of building a model-driven system.\n", "\n", "\n", "## Architecture\n", @@ -88,11 +96,32 @@ "});\n", "const db = await SqlDatabase.fromDataSourceParams({\n", " appDataSource: datasource,\n", + " // Security options (new in enhanced SqlDatabase):\n", + " // allowedStatements: [\"SELECT\"], // Restrict to read-only queries\n", + " // enableSqlValidation: true, // SQL injection protection (default: true)\n", + " // maxQueryLength: 10000, // Query length limit (default: 10000)\n", "});\n", "\n", "await db.run(\"SELECT * FROM Artist LIMIT 10;\")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "// ✅ SECURE: Using parameterized queries (recommended for user input)\n", + "// Parameters are safely bound, preventing SQL injection\n", + "const artistId = 5;\n", + "const limit = 3;\n", + "const safeResult = await db.run(\n", + " \"SELECT Name FROM Artist WHERE ArtistId > ? LIMIT ?\", \n", + " [artistId, limit]\n", + ");\n", + "console.log(\"Parameterized query result:\", safeResult);\n" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/langchain/src/sql_db.ts b/langchain/src/sql_db.ts index b1be4e3a6774..f60fa843f5bb 100644 --- a/langchain/src/sql_db.ts +++ b/langchain/src/sql_db.ts @@ -14,17 +14,72 @@ import { export type { SqlDatabaseDataSourceParams, SqlDatabaseOptionsParams }; +/** + * Patterns to detect dangerous SQL commands. + */ +const dangerousPatterns = [ + /;\s*drop\s+/i, // DROP statements + /;\s*delete\s+/i, // DELETE statements + /;\s*update\s+/i, // UPDATE statements + /;\s*insert\s+/i, // INSERT statements + /;\s*alter\s+/i, // ALTER statements + /;\s*create\s+/i, // CREATE statements + /;\s*truncate\s+/i, // TRUNCATE statements + /;\s*exec\s*\(/i, // EXEC statements + /;\s*execute\s*\(/i, // EXECUTE statements + /xp_cmdshell/i, // SQL Server command execution + /sp_executesql/i, // SQL Server dynamic SQL + /--[^\r\n]*/g, // SQL comments (can hide malicious code) + /\/\*[\s\S]*?\*\//g, // Multi-line comments + /\bunion\s+select\b/i, // Union-based injection + /\bor\s+1\s*=\s*1\b/i, // Common injection pattern + /\band\s+1\s*=\s*1\b/i, // Common injection pattern + /'\s*or\s*'1'\s*=\s*'1/i, // String-based injection + /;\s*shutdown\s+/i, // Database shutdown + /;\s*backup\s+/i, // Database backup + /;\s*restore\s+/i, // Database restore +]; + +/** + * Allowed SQL statements. + * + * @todo(@christian-bromann): In the next major version, the default allowed statements + * will be restricted to ["SELECT"] only for improved security. Users requiring other + * statement types should explicitly configure allowedStatements in the constructor. + */ +const ALLOWED_STATEMENTS = [ + "SELECT", + "INSERT", + "UPDATE", + "DELETE", + "CREATE", + "DROP", + "ALTER", +] as const; + +const DEFAULT_SAMPLE_ROWS_IN_TABLE_INFO = 3; + +const DEFAULT_MAX_QUERY_LENGTH = 10000; + +const DEFAULT_ENABLE_SQL_VALIDATION = true; + /** * Class that represents a SQL database in the LangChain framework. * * @security **Security Notice** - * This class generates SQL queries for the given database. - * The SQLDatabase class provides a getTableInfo method that can be used - * to get column information as well as sample data from the table. - * To mitigate risk of leaking sensitive data, limit permissions - * to read and scope to the tables that are needed. - * Optionally, use the includesTables or ignoreTables class parameters - * to limit which tables can/cannot be accessed. + * This class executes SQL queries against a database, which poses significant security risks. + * + * **Security Best Practices:** + * 1. Use a database user with minimal read-only permissions + * 2. Scope access to only necessary tables using includesTables/ignoreTables + * 3. Keep enableSqlValidation=true in production + * 4. Monitor SQL execution logs for suspicious activity + * 5. Consider using prepared statements for complex queries + * + * **⚠️ Breaking Change Notice:** + * In the next major version, the default `allowedStatements` will be restricted + * to `["SELECT"]` only for improved security. Applications requiring other SQL + * operations should explicitly configure `allowedStatements` in the constructor. * * @link See https://js.langchain.com/docs/security for more information. */ @@ -48,10 +103,16 @@ export class SqlDatabase ignoreTables: Array = []; - sampleRowsInTableInfo = 3; + sampleRowsInTableInfo = DEFAULT_SAMPLE_ROWS_IN_TABLE_INFO; customDescription?: Record; + allowedStatements: string[] = [...ALLOWED_STATEMENTS]; + + enableSqlValidation = DEFAULT_ENABLE_SQL_VALIDATION; + + maxQueryLength = DEFAULT_MAX_QUERY_LENGTH; + protected constructor(fields: SqlDatabaseDataSourceParams) { super(...arguments); this.appDataSource = fields.appDataSource; @@ -63,6 +124,11 @@ export class SqlDatabase this.ignoreTables = fields?.ignoreTables ?? []; this.sampleRowsInTableInfo = fields?.sampleRowsInTableInfo ?? this.sampleRowsInTableInfo; + this.allowedStatements = + fields?.allowedStatements ?? this.allowedStatements; + this.enableSqlValidation = + fields?.enableSqlValidation ?? this.enableSqlValidation; + this.maxQueryLength = fields?.maxQueryLength ?? this.maxQueryLength; } static async fromDataSourceParams( @@ -151,12 +217,85 @@ export class SqlDatabase * Execute a SQL command and return a string representing the results. * If the statement returns rows, a string of the results is returned. * If the statement returns no rows, an empty string is returned. + * + * @security This method executes raw SQL queries and has security implications. + * Only SELECT queries are allowed by default. To enable other operations, + * set allowedStatements in the constructor options. + * + * @example + * ```typescript + * // ✅ recommended + * const result = await db.run("SELECT * FROM users WHERE age > ?", [18]); + * // ❌ not recommended + * const result = await db.run("SELECT * FROM users WHERE age > 18"); + * ``` + * + * @param command - SQL query string + * @param fetch - Return "all" rows or just "one" + * @returns JSON string of results */ - async run(command: string, fetch: "all" | "one" = "all"): Promise { - // TODO: Potential security issue here - const res = await this.appDataSource.query(command); + async run(command: string, fetch?: "all" | "one"): Promise; - if (fetch === "all") { + /** + * Execute a parameterized SQL query with safer parameter binding. + * This overload is recommended for queries with user input. + * + * @param command - SQL query with parameter placeholders (?) + * @param parameters - Array of parameter values to bind + * @param fetch - Return "all" rows or just "one" + * @returns JSON string of results + * + * @example + * ```typescript + * const result = await db.run( + * "SELECT * FROM users WHERE age > ? AND name = ?", + * [18, "John"] + * ); + * ``` + */ + async run( + command: string, + parameters: unknown[], + fetch?: "all" | "one" + ): Promise; + + /** + * Execute a SQL command with optional parameters and return results. + * + * @param command - SQL query string + * @param fetchOrParameters - Either fetch mode or parameters array + * @param fetch - Fetch mode when parameters are provided + * @returns JSON string of results + */ + async run( + command: string, + fetchOrParameters?: "all" | "one" | unknown[], + fetch: "all" | "one" = "all" + ): Promise { + let parameters: unknown[] | undefined; + let actualFetch: "all" | "one" = "all"; + + // Determine if second parameter is fetch mode or parameters array + if (Array.isArray(fetchOrParameters)) { + parameters = fetchOrParameters; + actualFetch = fetch; + } else if (fetchOrParameters === "all" || fetchOrParameters === "one") { + actualFetch = fetchOrParameters; + } else if (fetchOrParameters === undefined) { + actualFetch = "all"; + } + + // Validate and sanitize the SQL command if validation is enabled + if (this.enableSqlValidation) { + this.validateSqlCommand(command); + } + + // Execute query with or without parameters + const res = parameters + ? await this.appDataSource.query(command, parameters) + : await this.appDataSource.query(command); + + if (actualFetch === "all") { return JSON.stringify(res); } @@ -167,6 +306,55 @@ export class SqlDatabase return ""; } + /** + * Validates a SQL command for security vulnerabilities. + * Throws an error if the command is potentially unsafe. + */ + private validateSqlCommand(command: string): void { + if (!command || typeof command !== "string") { + throw new Error("SQL command must be a non-empty string"); + } + + // Check for dangerous patterns + for (const pattern of dangerousPatterns) { + if (pattern.test(command)) { + throw new Error( + `Potentially unsafe SQL command detected. Pattern: ${pattern.source}` + ); + } + } + + // Remove leading/trailing whitespace and normalize + const normalizedCommand = command.trim().toLowerCase(); + + // Check if the command starts with an allowed statement + const startsWithAllowedStatement = this.allowedStatements.some((stmt) => + normalizedCommand.startsWith(stmt.toLowerCase()) + ); + if (!startsWithAllowedStatement) { + throw new Error( + `Only ${this.allowedStatements.join( + ", " + )} queries are allowed for security reasons` + ); + } + + // Check for multiple statements (semicolon followed by non-whitespace) + const statementCount = command + .split(";") + .filter((stmt) => stmt.trim().length > 0).length; + if (statementCount > 1) { + throw new Error("Multiple SQL statements are not allowed"); + } + + // Additional validation: check for excessively long queries (potential DoS) + if (command.length > this.maxQueryLength) { + throw new Error( + `SQL command exceeds maximum allowed length of ${this.maxQueryLength} characters` + ); + } + } + serialize(): SerializedSqlDatabase { return { _type: "sql_database", diff --git a/langchain/src/tests/sql_database.test.ts b/langchain/src/tests/sql_database.test.ts new file mode 100644 index 000000000000..8acbb172955d --- /dev/null +++ b/langchain/src/tests/sql_database.test.ts @@ -0,0 +1,468 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ +import { test, expect, describe, jest } from "@jest/globals"; +import { SqlDatabase } from "../sql_db.js"; + +// Simple mock DataSource that implements the interface +const createMockDataSource = (queryMock?: jest.Mock) => ({ + query: + queryMock || + jest.fn().mockResolvedValue([{ id: 1, name: "test" }] as never), + initialize: jest.fn().mockResolvedValue(undefined as never), + destroy: jest.fn().mockResolvedValue(undefined as never), + isInitialized: true, + options: { + type: "sqlite" as const, + database: ":memory:", + }, +}); + +describe("SqlDatabase Security Features - Unit Tests", () => { + describe("Constructor and Configuration", () => { + test("should initialize with default security settings", async () => { + const mockDataSource = createMockDataSource(); + const db = await SqlDatabase.fromDataSourceParams({ + appDataSource: mockDataSource as any, + }); + + expect(db.enableSqlValidation).toBe(true); + expect(db.maxQueryLength).toBe(10000); + expect(db.allowedStatements).toEqual([ + "SELECT", + "INSERT", + "UPDATE", + "DELETE", + "CREATE", + "DROP", + "ALTER", + ]); + }); + + test("should accept custom security configuration", async () => { + const mockDataSource = createMockDataSource(); + const db = await SqlDatabase.fromDataSourceParams({ + appDataSource: mockDataSource as any, + allowedStatements: ["SELECT"], + enableSqlValidation: false, + maxQueryLength: 5000, + }); + + expect(db.enableSqlValidation).toBe(false); + expect(db.maxQueryLength).toBe(5000); + expect(db.allowedStatements).toEqual(["SELECT"]); + }); + }); + + describe("SQL Validation", () => { + test("should allow valid SELECT queries when configured", async () => { + const queryMock = jest + .fn() + .mockResolvedValue([{ id: 1, name: "test" }] as never); + const mockDataSource = createMockDataSource(queryMock); + + const db = await SqlDatabase.fromDataSourceParams({ + appDataSource: mockDataSource as any, + allowedStatements: ["SELECT"], + }); + + await db.run("SELECT * FROM products"); + expect(queryMock).toHaveBeenCalledWith("SELECT * FROM products"); + }); + + test("should block unauthorized statement types", async () => { + const queryMock = jest.fn().mockResolvedValue([] as never); + const mockDataSource = createMockDataSource(); + + const db = await SqlDatabase.fromDataSourceParams({ + appDataSource: mockDataSource as any, + allowedStatements: ["SELECT"], + }); + + await expect(db.run("DELETE FROM products")).rejects.toThrow( + "Only SELECT queries are allowed for security reasons" + ); + await expect(db.run("UPDATE products SET price = 999")).rejects.toThrow( + "Only SELECT queries are allowed for security reasons" + ); + await expect( + db.run("INSERT INTO products (name, price) VALUES ('test', 1)") + ).rejects.toThrow("Only SELECT queries are allowed for security reasons"); + + expect(queryMock).not.toHaveBeenCalled(); + }); + + const maliciousQueries = [ + "SELECT * FROM users; DROP TABLE products;", // Multiple statements (injection) + "SELECT * FROM users WHERE id = 1 OR 1=1", // OR injection + "SELECT * FROM users WHERE name = 'Alice' OR '1'='1'", // String-based injection + "SELECT * FROM users; DELETE FROM products;", // Multiple statements (injection) + "SELECT * FROM users; --", // SQL comment injection + "SELECT * FROM users /* comment */ UNION SELECT * FROM products", // Comment + UNION injection + "SELECT * FROM users; EXEC xp_cmdshell('dir')", // Command execution injection + ]; + test.each(maliciousQueries)( + 'should detect SQL injection patterns ("%s")', + async (query) => { + const queryMock = jest.fn(); + const mockDataSource = createMockDataSource(); + + const db = await SqlDatabase.fromDataSourceParams({ + appDataSource: mockDataSource as any, + }); + await expect(db.run(query)).rejects.toThrow(); + expect(queryMock).not.toHaveBeenCalled(); + } + ); + + const additionalIllegitimateQueries = [ + "SELECT * FROM users UNION SELECT * FROM products", // UNION injection + "SELECT * FROM users; SHUTDOWN;", // Database shutdown + "SELECT * FROM users; BACKUP DATABASE;", // Database backup + "SELECT * FROM users; RESTORE DATABASE;", // Database restore + "SELECT * FROM users; TRUNCATE TABLE products;", // TRUNCATE injection + "SELECT * FROM users WHERE id = 1; EXEC sp_executesql 'DROP TABLE products'", // Stored procedure execution + "SELECT * FROM users WHERE name = '' OR '1'='1' --", // Classic SQL injection with comment + "SELECT * FROM users WHERE id = 1' UNION SELECT username, password FROM admin_users --", // Classic UNION injection + "SELECT * FROM users; xp_cmdshell('rm -rf /')", // Command shell execution + ]; + test.each(additionalIllegitimateQueries)( + 'should block advanced SQL injection attempts ("%s")', + async (query) => { + const queryMock = jest.fn(); + const mockDataSource = createMockDataSource(); + + const db = await SqlDatabase.fromDataSourceParams({ + appDataSource: mockDataSource as any, + }); + await expect(db.run(query)).rejects.toThrow(); + expect(queryMock).not.toHaveBeenCalled(); + } + ); + + // These should all be allowed since they're in the default allowed statements + const legitimateQueries = [ + "SELECT * FROM users", + "INSERT INTO users (name) VALUES ('test')", + "UPDATE users SET name = 'updated'", + "DELETE FROM old_table", + "CREATE TABLE new_table (id INT)", + "DROP TABLE old_table", + "ALTER TABLE users ADD COLUMN email TEXT", + ]; + test.each(legitimateQueries)( + 'should allow legitimate single statements from allowed list ("%s")', + async (query) => { + const queryMock = jest.fn().mockResolvedValue([] as never); + const mockDataSource = createMockDataSource(queryMock); + + const db = await SqlDatabase.fromDataSourceParams({ + appDataSource: mockDataSource as any, + // Using default allowed statements which include all major SQL operations for backward compatibility + }); + + await db.run(query); + expect(queryMock).toHaveBeenCalledWith(query); + } + ); + + test("should reject multiple statements", async () => { + const queryMock = jest.fn(); + const mockDataSource = createMockDataSource(); + + const db = await SqlDatabase.fromDataSourceParams({ + appDataSource: mockDataSource as any, + }); + + await expect( + db.run("SELECT * FROM users; SELECT * FROM products;") + ).rejects.toThrow("Multiple SQL statements are not allowed"); + expect(queryMock).not.toHaveBeenCalled(); + }); + + test("should enforce maximum query length", async () => { + const queryMock = jest.fn(); + const mockDataSource = createMockDataSource(); + + const db = await SqlDatabase.fromDataSourceParams({ + appDataSource: mockDataSource as any, + maxQueryLength: 50, + }); + + const longQuery = `SELECT * FROM products WHERE ${"name = 'test' AND ".repeat( + 20 + )}id = 1`; + await expect(db.run(longQuery)).rejects.toThrow( + "SQL command exceeds maximum allowed length" + ); + expect(queryMock).not.toHaveBeenCalled(); + }); + + test("should validate query input types", async () => { + const queryMock = jest.fn(); + const mockDataSource = createMockDataSource(); + + const db = await SqlDatabase.fromDataSourceParams({ + appDataSource: mockDataSource as any, + }); + + // @ts-expect-error - Testing invalid input + await expect(db.run(null)).rejects.toThrow( + "SQL command must be a non-empty string" + ); + await expect(db.run("")).rejects.toThrow( + "SQL command must be a non-empty string" + ); + // @ts-expect-error - Testing invalid input + await expect(db.run(123)).rejects.toThrow( + "SQL command must be a non-empty string" + ); + + expect(queryMock).not.toHaveBeenCalled(); + }); + + test("should allow disabling SQL validation", async () => { + const queryMock = jest.fn().mockResolvedValue([] as never); + const mockDataSource = createMockDataSource(queryMock); + + const db = await SqlDatabase.fromDataSourceParams({ + appDataSource: mockDataSource as any, + enableSqlValidation: false, + }); + + await db.run("SELECT * FROM products"); + expect(queryMock).toHaveBeenCalledWith("SELECT * FROM products"); + }); + }); + + describe("Parameterized Queries", () => { + test("should support parameterized queries with array parameters", async () => { + const queryMock = jest.fn().mockResolvedValue([] as never); + const mockDataSource = createMockDataSource(queryMock); + + const db = await SqlDatabase.fromDataSourceParams({ + appDataSource: mockDataSource as any, + allowedStatements: ["SELECT"], + }); + + await db.run("SELECT * FROM users WHERE age > ?", [20]); + expect(queryMock).toHaveBeenCalledWith( + "SELECT * FROM users WHERE age > ?", + [20] + ); + }); + + test("should support parameterized queries with multiple parameters", async () => { + const queryMock = jest.fn().mockResolvedValue([] as never); + const mockDataSource = createMockDataSource(queryMock); + + const db = await SqlDatabase.fromDataSourceParams({ + appDataSource: mockDataSource as any, + allowedStatements: ["SELECT"], + }); + + await db.run("SELECT * FROM users WHERE age >= ? AND name = ?", [ + 20, + "Alice", + ]); + expect(queryMock).toHaveBeenCalledWith( + "SELECT * FROM users WHERE age >= ? AND name = ?", + [20, "Alice"] + ); + }); + + test("should validate parameterized queries for security", async () => { + const queryMock = jest.fn(); + const mockDataSource = createMockDataSource(); + + const db = await SqlDatabase.fromDataSourceParams({ + appDataSource: mockDataSource as any, + allowedStatements: ["SELECT"], + }); + + // Even with parameters, injection attempts in the query string should be blocked + await expect( + db.run("SELECT * FROM users; DROP TABLE products", [1]) + ).rejects.toThrow(); + expect(queryMock).not.toHaveBeenCalled(); + }); + }); + + describe("Security Edge Cases", () => { + const caseInsensitiveStatements = [ + "select * from users", + "SELECT * FROM users", + " SELECT * FROM users ", + ]; + test.each(caseInsensitiveStatements)( + 'should handle case-insensitive statement detection ("%s")', + async (query) => { + const queryMock = jest.fn().mockResolvedValue([] as never); + const mockDataSource = createMockDataSource(queryMock); + const db = await SqlDatabase.fromDataSourceParams({ + appDataSource: mockDataSource as any, + allowedStatements: ["SELECT"], + }); + + await db.run(query); + expect(queryMock).toHaveBeenCalledWith(query); + } + ); + + const injectionPatterns = [ + "SELECT * FROM users; drop table products;", // Case insensitive injection + "SELECT * FROM users; DROP TABLE products;", // Multiple statement injection + "SELECT * FROM users; Delete FROM products;", // Mixed case injection + ]; + test.each(injectionPatterns)( + 'should detect SQL injection patterns regardless of case ("%s")', + async (query) => { + const queryMock = jest.fn().mockResolvedValue([] as never); + const mockDataSource = createMockDataSource(); + const db = await SqlDatabase.fromDataSourceParams({ + appDataSource: mockDataSource as any, + }); + await expect(db.run(query)).rejects.toThrow(); + expect(queryMock).not.toHaveBeenCalled(); + } + ); + + const commentBasedInjectionAttempts = [ + "SELECT * FROM users -- DROP TABLE products", + "SELECT * FROM users /* DROP TABLE products */", + "SELECT * FROM users; /* hidden command */ DELETE FROM products; /* end */", + ]; + test.each(commentBasedInjectionAttempts)( + 'should prevent comment-based injection attempts ("%s")', + async (query) => { + const queryMock = jest.fn().mockResolvedValue([] as never); + const mockDataSource = createMockDataSource(); + const db = await SqlDatabase.fromDataSourceParams({ + appDataSource: mockDataSource as any, + }); + + await expect(db.run(query)).rejects.toThrow(); + expect(queryMock).not.toHaveBeenCalled(); + } + ); + + test("should handle encoding-based injection attempts", async () => { + const queryMock = jest.fn().mockResolvedValue([] as never); + const mockDataSource = createMockDataSource(); + const db = await SqlDatabase.fromDataSourceParams({ + appDataSource: mockDataSource as any, + }); + + await expect( + db.run( + "SELECT * FROM users WHERE id = 1\u003B DROP TABLE products\u003B" + ) + ).rejects.toThrow(); + expect(queryMock).not.toHaveBeenCalled(); + }); + }); + + describe("Advanced Security Tests", () => { + test("should block time-based injection patterns with multiple statements", async () => { + const queryMock = jest.fn(); + const mockDataSource = createMockDataSource(); + + const db = await SqlDatabase.fromDataSourceParams({ + appDataSource: mockDataSource as any, + }); + + const timeBasedInjectionQueries = [ + "SELECT * FROM users WHERE id = 1; WAITFOR DELAY '00:00:10'; --", // Multiple statements - blocked + "SELECT * FROM users WHERE id = 1; pg_sleep(10); --", // Multiple statements - blocked + ]; + + for (const query of timeBasedInjectionQueries) { + await expect(db.run(query)).rejects.toThrow(); + expect(queryMock).not.toHaveBeenCalled(); + } + }); + + test("should block OR-based injection patterns", async () => { + const queryMock = jest.fn(); + const mockDataSource = createMockDataSource(); + + const db = await SqlDatabase.fromDataSourceParams({ + appDataSource: mockDataSource as any, + }); + + // These patterns are specifically detected by the validation + const orInjectionQueries = [ + "SELECT * FROM users WHERE id = 1 OR 1=1", + "SELECT * FROM users WHERE name = 'test' OR '1'='1'", + ]; + + for (const query of orInjectionQueries) { + await expect(db.run(query)).rejects.toThrow(); + expect(queryMock).not.toHaveBeenCalled(); + } + }); + + test("should block stacked queries (multiple statements)", async () => { + const queryMock = jest.fn(); + const mockDataSource = createMockDataSource(); + + const db = await SqlDatabase.fromDataSourceParams({ + appDataSource: mockDataSource as any, + }); + + const stackedQueries = [ + "SELECT * FROM users; INSERT INTO logs VALUES ('hacked');", + "SELECT name FROM users; UPDATE users SET admin = 1;", + "SELECT id FROM products; DROP DATABASE test;", + ]; + + for (const query of stackedQueries) { + // The validation will catch these as dangerous patterns, not specifically as "Multiple SQL statements" + await expect(db.run(query)).rejects.toThrow(); + expect(queryMock).not.toHaveBeenCalled(); + } + }); + }); + + describe("Configuration Validation", () => { + test("should accept valid allowed statements array", async () => { + const mockDataSource = createMockDataSource(); + const customStatements = ["SELECT", "INSERT"]; + + const db = await SqlDatabase.fromDataSourceParams({ + appDataSource: mockDataSource as any, + allowedStatements: customStatements, + }); + + expect(db.allowedStatements).toEqual(customStatements); + }); + + test("should handle empty allowed statements array", async () => { + const queryMock = jest.fn(); + const mockDataSource = createMockDataSource(); + + const db = await SqlDatabase.fromDataSourceParams({ + appDataSource: mockDataSource as any, + allowedStatements: [], + }); + + await expect(db.run("SELECT * FROM users")).rejects.toThrow( + "Only queries are allowed for security reasons" + ); + expect(queryMock).not.toHaveBeenCalled(); + }); + + test("should respect custom maxQueryLength", async () => { + const queryMock = jest.fn(); + const mockDataSource = createMockDataSource(); + + const db = await SqlDatabase.fromDataSourceParams({ + appDataSource: mockDataSource as any, + maxQueryLength: 20, + }); + + await expect(db.run("SELECT * FROM users WHERE id = 1")).rejects.toThrow( + "SQL command exceeds maximum allowed length" + ); + expect(queryMock).not.toHaveBeenCalled(); + }); + }); +}); diff --git a/langchain/src/util/sql_utils.ts b/langchain/src/util/sql_utils.ts index f08dc9a47048..c2dc92694e40 100644 --- a/langchain/src/util/sql_utils.ts +++ b/langchain/src/util/sql_utils.ts @@ -18,10 +18,43 @@ interface RawResultTableAndColumn { } export interface SqlDatabaseParams { - includesTables?: Array; - ignoreTables?: Array; + /** + * Tables to include in the database. + */ + includesTables?: string[]; + /** + * Tables to ignore in the database. + */ + ignoreTables?: string[]; + /** + * Number of rows to sample from each table for table info. + * @default 3 + */ sampleRowsInTableInfo?: number; + /** + * Custom description for each table. + */ customDescription?: Record; + /** + * Allowed SQL statements. + * + * Note: In the next major version, the default allowed statements + * will be restricted to ["SELECT"] only for improved security. Users requiring other + * statement types should explicitly configure allowedStatements in the constructor. + * + * @default ["SELECT", "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER"] + */ + allowedStatements?: string[]; + /** + * Enable SQL validation. + * @default true + */ + enableSqlValidation?: boolean; + /** + * Maximum allowed query length in characters. + * @default 10000 + */ + maxQueryLength?: number; } export interface SqlDatabaseOptionsParams extends SqlDatabaseParams {