diff --git a/composer.json b/composer.json index bd8c90f..71ac47d 100644 --- a/composer.json +++ b/composer.json @@ -17,8 +17,8 @@ ], "require": { "php": "^8.1", - "doctrine/dbal": "^3.6", - "illuminate/contracts": "^9.0|^10.0", + "illuminate/contracts": "^10.0", + "illuminate/database": "^10.43", "openai-php/laravel": "^0.3.1", "spatie/laravel-package-tools": "^1.14.0", "spatie/once": "^3.1" diff --git a/resources/views/prompts/query.blade.php b/resources/views/prompts/query.blade.php index 84a5e56..fd49ec6 100644 --- a/resources/views/prompts/query.blade.php +++ b/resources/views/prompts/query.blade.php @@ -9,7 +9,7 @@ Only use the following tables and columns: @foreach($tables as $table) -"{{ $table->getName() }}" has columns: {{ collect($table->getColumns())->map(fn(\Doctrine\DBAL\Schema\Column $column) => $column->getName() . ' ('.$column->getType()->getName().')')->implode(', ') }} +"{{ $table }}" has columns: {{ collect(\Illuminate\Support\Facades\Schema::getColumns($table))->map(fn(array $column) => $column['name'] . ' ('.$column['type_name'].')')->implode(', ') }} @endforeach Question: "{!! $question !!}" diff --git a/resources/views/prompts/tables.blade.php b/resources/views/prompts/tables.blade.php index 251487a..055adf1 100644 --- a/resources/views/prompts/tables.blade.php +++ b/resources/views/prompts/tables.blade.php @@ -1,5 +1,5 @@ Given the below input question and list of potential tables, output a comma separated list of the table names that may be necessary to answer this question. Question: {{ $question }} -Table Names: @foreach($tables as $table){{ $table->getName() }},@endforeach +Table Names: @foreach($tables as $table){{ $table }},@endforeach Relevant Table Names: diff --git a/src/Oracle.php b/src/Oracle.php index 8094fee..4d118d2 100755 --- a/src/Oracle.php +++ b/src/Oracle.php @@ -4,6 +4,7 @@ use BeyondCode\Oracle\Exceptions\PotentiallyUnsafeQuery; use Illuminate\Support\Facades\DB; +use Illuminate\Support\Facades\Schema; use Illuminate\Support\Str; use OpenAI\Client; @@ -104,20 +105,22 @@ protected function ensureQueryIsSafe(string $query): void protected function getDialect(): string { - $databasePlatform = DB::connection($this->connection)->getDoctrineConnection()->getDatabasePlatform(); - - return Str::before(class_basename($databasePlatform), 'Platform'); + $connection = DB::connection($this->connection); + + return match (true) { + $connection instanceof \Illuminate\Database\MySqlConnection && $connection->isMaria() => 'MariaDB', + $connection instanceof \Illuminate\Database\MySqlConnection => 'MySQL', + $connection instanceof \Illuminate\Database\PostgresConnection => 'PostgreSQL', + $connection instanceof \Illuminate\Database\SQLiteConnection => 'SQLite', + $connection instanceof \Illuminate\Database\SqlServerConnection => 'SQL Server', + default => $connection->getDriverName(), + }; } - /** - * @return \Doctrine\DBAL\Schema\Table[] - */ protected function getTables(string $question): array { return once(function () use ($question) { - $tables = DB::connection($this->connection) - ->getDoctrineSchemaManager() - ->listTables(); + $tables = Schema::connection($this->connection)->getTableListing(); if (count($tables) < config('ask-database.max_tables_before_performing_lookup')) { return $tables; @@ -142,7 +145,7 @@ protected function filterMatchingTables(string $question, array $tables): array ->transform(fn (string $tableName) => strtolower(trim($tableName))); return collect($tables)->filter(function ($table) use ($matchingTables) { - return $matchingTables->contains(strtolower($table->getName())); + return $matchingTables->contains(strtolower($table)); })->toArray(); } } diff --git a/tests/Fixtures/filtered-query-prompt.txt b/tests/Fixtures/filtered-query-prompt.txt index 555bd82..d3477d4 100644 --- a/tests/Fixtures/filtered-query-prompt.txt +++ b/tests/Fixtures/filtered-query-prompt.txt @@ -1,4 +1,4 @@ -Given an input question, first create a syntactically correct Sqlite query to run, then look at the results of the query and return the answer. +Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer. Use the following format: Question: "Question here" @@ -8,7 +8,7 @@ Answer: "Final answer here" Only use the following tables and columns: -"users" has columns: id (integer), name (string), email (string), email_verified_at (datetime), password (string), remember_token (string), created_at (datetime), updated_at (datetime) +"users" has columns: id (integer), name (varchar), email (varchar), email_verified_at (datetime), password (varchar), remember_token (varchar), created_at (datetime), updated_at (datetime) Question: "How many users do you have?" SQLQuery: " diff --git a/tests/Fixtures/filtered-result-prompt.txt b/tests/Fixtures/filtered-result-prompt.txt index 2a73acb..0ee21be 100644 --- a/tests/Fixtures/filtered-result-prompt.txt +++ b/tests/Fixtures/filtered-result-prompt.txt @@ -1,4 +1,4 @@ -Given an input question, first create a syntactically correct Sqlite query to run, then look at the results of the query and return the answer. +Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer. Use the following format: Question: "Question here" @@ -8,7 +8,7 @@ Answer: "Final answer here" Only use the following tables and columns: -"users" has columns: id (integer), name (string), email (string), email_verified_at (datetime), password (string), remember_token (string), created_at (datetime), updated_at (datetime) +"users" has columns: id (integer), name (varchar), email (varchar), email_verified_at (datetime), password (varchar), remember_token (varchar), created_at (datetime), updated_at (datetime) Question: "How many users do you have?" SQLQuery: "SELECT COUNT(*) FROM users;" diff --git a/tests/Fixtures/query-prompt-l10.txt b/tests/Fixtures/query-prompt-l10.txt index c4eb3c0..4686a39 100644 --- a/tests/Fixtures/query-prompt-l10.txt +++ b/tests/Fixtures/query-prompt-l10.txt @@ -1,4 +1,4 @@ -Given an input question, first create a syntactically correct Sqlite query to run, then look at the results of the query and return the answer. +Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer. Use the following format: Question: "Question here" @@ -8,10 +8,10 @@ Answer: "Final answer here" Only use the following tables and columns: -"failed_jobs" has columns: id (integer), uuid (string), connection (text), queue (text), payload (text), exception (text), failed_at (datetime) -"migrations" has columns: id (integer), migration (string), batch (integer) -"password_reset_tokens" has columns: email (string), token (string), created_at (datetime) -"users" has columns: id (integer), name (string), email (string), email_verified_at (datetime), password (string), remember_token (string), created_at (datetime), updated_at (datetime) +"failed_jobs" has columns: id (integer), uuid (varchar), connection (text), queue (text), payload (text), exception (text), failed_at (datetime) +"migrations" has columns: id (integer), migration (varchar), batch (integer) +"password_reset_tokens" has columns: email (varchar), token (varchar), created_at (datetime) +"users" has columns: id (integer), name (varchar), email (varchar), email_verified_at (datetime), password (varchar), remember_token (varchar), created_at (datetime), updated_at (datetime) Question: "How many users do you have?" SQLQuery: " diff --git a/tests/Fixtures/query-prompt.txt b/tests/Fixtures/query-prompt.txt index 34d5e8a..4e3ddc8 100644 --- a/tests/Fixtures/query-prompt.txt +++ b/tests/Fixtures/query-prompt.txt @@ -1,4 +1,4 @@ -Given an input question, first create a syntactically correct Sqlite query to run, then look at the results of the query and return the answer. +Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer. Use the following format: Question: "Question here" @@ -8,10 +8,10 @@ Answer: "Final answer here" Only use the following tables and columns: -"failed_jobs" has columns: id (integer), uuid (string), connection (text), queue (text), payload (text), exception (text), failed_at (datetime) -"migrations" has columns: id (integer), migration (string), batch (integer) -"password_resets" has columns: email (string), token (string), created_at (datetime) -"users" has columns: id (integer), name (string), email (string), email_verified_at (datetime), password (string), remember_token (string), created_at (datetime), updated_at (datetime) +"failed_jobs" has columns: id (integer), uuid (varchar), connection (text), queue (text), payload (text), exception (text), failed_at (datetime) +"migrations" has columns: id (integer), migration (varchar), batch (integer) +"password_resets" has columns: email (varchar), token (varchar), created_at (datetime) +"users" has columns: id (integer), name (varchar), email (varchar), email_verified_at (datetime), password (varchar), remember_token (varchar), created_at (datetime), updated_at (datetime) Question: "How many users do you have?" SQLQuery: " diff --git a/tests/Fixtures/result-prompt-l10.txt b/tests/Fixtures/result-prompt-l10.txt index ebf2aa5..eee5ad9 100644 --- a/tests/Fixtures/result-prompt-l10.txt +++ b/tests/Fixtures/result-prompt-l10.txt @@ -1,4 +1,4 @@ -Given an input question, first create a syntactically correct Sqlite query to run, then look at the results of the query and return the answer. +Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer. Use the following format: Question: "Question here" @@ -8,10 +8,10 @@ Answer: "Final answer here" Only use the following tables and columns: -"failed_jobs" has columns: id (integer), uuid (string), connection (text), queue (text), payload (text), exception (text), failed_at (datetime) -"migrations" has columns: id (integer), migration (string), batch (integer) -"password_reset_tokens" has columns: email (string), token (string), created_at (datetime) -"users" has columns: id (integer), name (string), email (string), email_verified_at (datetime), password (string), remember_token (string), created_at (datetime), updated_at (datetime) +"failed_jobs" has columns: id (integer), uuid (varchar), connection (text), queue (text), payload (text), exception (text), failed_at (datetime) +"migrations" has columns: id (integer), migration (varchar), batch (integer) +"password_reset_tokens" has columns: email (varchar), token (varchar), created_at (datetime) +"users" has columns: id (integer), name (varchar), email (varchar), email_verified_at (datetime), password (varchar), remember_token (varchar), created_at (datetime), updated_at (datetime) Question: "How many users do you have?" SQLQuery: "SELECT COUNT(*) FROM users;" diff --git a/tests/Fixtures/result-prompt.txt b/tests/Fixtures/result-prompt.txt index 36f2db0..ecf2093 100644 --- a/tests/Fixtures/result-prompt.txt +++ b/tests/Fixtures/result-prompt.txt @@ -1,4 +1,4 @@ -Given an input question, first create a syntactically correct Sqlite query to run, then look at the results of the query and return the answer. +Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer. Use the following format: Question: "Question here" @@ -8,10 +8,10 @@ Answer: "Final answer here" Only use the following tables and columns: -"failed_jobs" has columns: id (integer), uuid (string), connection (text), queue (text), payload (text), exception (text), failed_at (datetime) -"migrations" has columns: id (integer), migration (string), batch (integer) -"password_resets" has columns: email (string), token (string), created_at (datetime) -"users" has columns: id (integer), name (string), email (string), email_verified_at (datetime), password (string), remember_token (string), created_at (datetime), updated_at (datetime) +"failed_jobs" has columns: id (integer), uuid (varchar), connection (text), queue (text), payload (text), exception (text), failed_at (datetime) +"migrations" has columns: id (integer), migration (varchar), batch (integer) +"password_resets" has columns: email (varchar), token (varchar), created_at (datetime) +"users" has columns: id (integer), name (varchar), email (varchar), email_verified_at (datetime), password (varchar), remember_token (varchar), created_at (datetime), updated_at (datetime) Question: "How many users do you have?" SQLQuery: "SELECT COUNT(*) FROM users;" diff --git a/tests/Fixtures/unsafe-query-prompt-l10.txt b/tests/Fixtures/unsafe-query-prompt-l10.txt index bf53ce2..34b550e 100644 --- a/tests/Fixtures/unsafe-query-prompt-l10.txt +++ b/tests/Fixtures/unsafe-query-prompt-l10.txt @@ -1,4 +1,4 @@ -Given an input question, first create a syntactically correct Sqlite query to run, then look at the results of the query and return the answer. +Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer. Use the following format: Question: "Question here" @@ -8,10 +8,10 @@ Answer: "Final answer here" Only use the following tables and columns: -"failed_jobs" has columns: id (integer), uuid (string), connection (text), queue (text), payload (text), exception (text), failed_at (datetime) -"migrations" has columns: id (integer), migration (string), batch (integer) -"password_reset_tokens" has columns: email (string), token (string), created_at (datetime) -"users" has columns: id (integer), name (string), email (string), email_verified_at (datetime), password (string), remember_token (string), created_at (datetime), updated_at (datetime) +"failed_jobs" has columns: id (integer), uuid (varchar), connection (text), queue (text), payload (text), exception (text), failed_at (datetime) +"migrations" has columns: id (integer), migration (varchar), batch (integer) +"password_reset_tokens" has columns: email (varchar), token (varchar), created_at (datetime) +"users" has columns: id (integer), name (varchar), email (varchar), email_verified_at (datetime), password (varchar), remember_token (varchar), created_at (datetime), updated_at (datetime) Question: "How many users do you have?" SQLQuery: "DROP TABLE users;" diff --git a/tests/Fixtures/unsafe-query-prompt.txt b/tests/Fixtures/unsafe-query-prompt.txt index b4797a9..1ca5d8c 100644 --- a/tests/Fixtures/unsafe-query-prompt.txt +++ b/tests/Fixtures/unsafe-query-prompt.txt @@ -1,4 +1,4 @@ -Given an input question, first create a syntactically correct Sqlite query to run, then look at the results of the query and return the answer. +Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer. Use the following format: Question: "Question here" @@ -8,10 +8,10 @@ Answer: "Final answer here" Only use the following tables and columns: -"failed_jobs" has columns: id (integer), uuid (string), connection (text), queue (text), payload (text), exception (text), failed_at (datetime) -"migrations" has columns: id (integer), migration (string), batch (integer) -"password_resets" has columns: email (string), token (string), created_at (datetime) -"users" has columns: id (integer), name (string), email (string), email_verified_at (datetime), password (string), remember_token (string), created_at (datetime), updated_at (datetime) +"failed_jobs" has columns: id (integer), uuid (varchar), connection (text), queue (text), payload (text), exception (text), failed_at (datetime) +"migrations" has columns: id (integer), migration (varchar), batch (integer) +"password_resets" has columns: email (varchar), token (varchar), created_at (datetime) +"users" has columns: id (integer), name (varchar), email (varchar), email_verified_at (datetime), password (varchar), remember_token (varchar), created_at (datetime), updated_at (datetime) Question: "How many users do you have?" SQLQuery: "DROP TABLE users;"