-
Notifications
You must be signed in to change notification settings - Fork 2.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #969 from IEsteve/postgres_tools
Basic PostgreSQL Tool
- Loading branch information
Showing
1 changed file
with
246 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,246 @@ | ||
from typing import Optional, Dict, Any | ||
|
||
try: | ||
import psycopg2 | ||
except ImportError: | ||
raise ImportError("`psycopg2` not installed. Please install using `pip install psycopg2`.") | ||
|
||
from phi.tools import Toolkit | ||
from phi.utils.log import logger | ||
|
||
|
||
class PostgresTools(Toolkit): | ||
"""A basic tool to connect to a PostgreSQL database and perform read-only operations on it. | ||
""" | ||
def __init__( | ||
self, | ||
connection: Optional[psycopg2.extensions.connection] = None, | ||
db_name: Optional[str] = None, | ||
user: Optional[str] = None, | ||
password: Optional[str] = None, | ||
host: Optional[str] = None, | ||
port: Optional[int] = None, | ||
run_queries: bool = True, | ||
inspect_queries: bool = False, | ||
summarize_tables: bool = True, | ||
export_tables: bool = False, | ||
): | ||
super().__init__(name="postgres_tools") | ||
self._connection: Optional[psycopg2.extensions.connection] = connection | ||
self.db_name: Optional[str] = db_name | ||
self.user: Optional[str] = user | ||
self.password: Optional[str] = password | ||
self.host: Optional[str] = host | ||
self.port: Optional[int] = port | ||
|
||
|
||
self.register(self.show_tables) | ||
self.register(self.describe_table) | ||
if inspect_queries: | ||
self.register(self.inspect_query) | ||
if run_queries: | ||
self.register(self.run_query) | ||
if summarize_tables: | ||
self.register(self.summarize_table) | ||
if export_tables: | ||
self.register(self.export_table_to_path) | ||
|
||
|
||
@property | ||
def connection(self) -> psycopg2.extensions.connection: | ||
""" | ||
Returns the Postgres psycopg2 connection. | ||
:return psycopg2.extensions.connection: psycopg2 connection | ||
""" | ||
if self._connection is None: | ||
connection_kwargs: Dict[str, Any] = {} | ||
if self.db_name is not None: | ||
connection_kwargs["database"] = self.db_name | ||
if self.user is not None: | ||
connection_kwargs["user"] = self.user | ||
if self.password is not None: | ||
connection_kwargs["password"] = self.password | ||
if self.host is not None: | ||
connection_kwargs["host"] = self.host | ||
if self.port is not None: | ||
connection_kwargs["port"] = self.port | ||
|
||
self._connection = psycopg2.connect(**connection_kwargs) | ||
self._connection.set_session(readonly=True) | ||
|
||
return self._connection | ||
|
||
|
||
def show_tables(self) -> str: | ||
"""Function to show tables in the database | ||
:return: List of tables in the database | ||
""" | ||
stmt = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'" | ||
tables = self.run_query(stmt) | ||
logger.debug(f"Tables: {tables}") | ||
return tables | ||
|
||
|
||
def describe_table(self, table: str) -> str: | ||
"""Function to describe a table | ||
:param table: Table to describe | ||
:return: Description of the table | ||
""" | ||
stmt = f"SELECT column_name, data_type, character_maximum_length FROM information_schema.columns WHERE table_name = '{table}';" | ||
table_description = self.run_query(stmt) | ||
|
||
logger.debug(f"Table description: {table_description}") | ||
return f"{table}\n{table_description}" | ||
|
||
|
||
def summarize_table(self, table: str, table_schema: Optional[str] = "public") -> str: | ||
"""Function to compute a number of aggregates over a table. | ||
The function launches a query that computes a number of aggregates over all columns, | ||
including min, max, avg, std and approx_unique. | ||
:param table: Table to summarize | ||
:return: Summary of the table | ||
""" | ||
stmt = f"""WITH column_stats AS ( | ||
SELECT | ||
column_name, | ||
data_type | ||
FROM | ||
information_schema.columns | ||
WHERE | ||
table_name = '{table}' | ||
AND table_schema = '{table_schema}' | ||
) | ||
SELECT | ||
column_name, | ||
data_type, | ||
COUNT(COALESCE(column_name::text, '')) AS non_null_count, | ||
COUNT(*) - COUNT(COALESCE(column_name::text, '')) AS null_count, | ||
SUM(COALESCE(column_name::numeric, 0)) AS sum, | ||
AVG(COALESCE(column_name::numeric, 0)) AS mean, | ||
MIN(column_name::numeric) AS min, | ||
MAX(column_name::numeric) AS max, | ||
STDDEV(COALESCE(column_name::numeric, 0)) AS stddev | ||
FROM | ||
column_stats, | ||
LATERAL ( | ||
SELECT | ||
* | ||
FROM | ||
{table} | ||
) AS tbl | ||
WHERE | ||
data_type IN ('integer', 'numeric', 'real', 'double precision') | ||
GROUP BY | ||
column_name, data_type | ||
UNION ALL | ||
SELECT | ||
column_name, | ||
data_type, | ||
COUNT(COALESCE(column_name::text, '')) AS non_null_count, | ||
COUNT(*) - COUNT(COALESCE(column_name::text, '')) AS null_count, | ||
NULL AS sum, | ||
NULL AS mean, | ||
NULL AS min, | ||
NULL AS max, | ||
NULL AS stddev | ||
FROM | ||
column_stats, | ||
LATERAL ( | ||
SELECT | ||
* | ||
FROM | ||
{table} | ||
) AS tbl | ||
WHERE | ||
data_type NOT IN ('integer', 'numeric', 'real', 'double precision') | ||
GROUP BY | ||
column_name, data_type; | ||
""" | ||
table_summary = self.run_query(stmt) | ||
|
||
logger.debug(f"Table summary: {table_summary}") | ||
return table_summary | ||
|
||
|
||
def inspect_query(self, query: str) -> str: | ||
"""Function to inspect a query and return the query plan. Always inspect your query before running them. | ||
:param query: Query to inspect | ||
:return: Query plan | ||
""" | ||
stmt = f"EXPLAIN {query};" | ||
explain_plan = self.run_query(stmt) | ||
|
||
logger.debug(f"Explain plan: {explain_plan}") | ||
return explain_plan | ||
|
||
|
||
def export_table_to_path(self, table: str, path: Optional[str] = None) -> str: | ||
"""Save a table in CSV format. | ||
If the path is provided, the table will be saved under that path. | ||
Eg: If path is /tmp, the table will be saved as /tmp/table.csv | ||
Otherwise it will be saved in the current directory | ||
:param table: Table to export | ||
:param path: Path to export to | ||
:return: None | ||
""" | ||
|
||
logger.debug(f"Exporting Table {table} as CSV to path {path}") | ||
if path is None: | ||
path = f"{table}.csv" | ||
else: | ||
path = f"{path}/{table}.csv" | ||
|
||
export_statement = f"COPY {table} TO '{path}' DELIMITER ',' CSV HEADER;" | ||
result = self.run_query(export_statement) | ||
logger.debug(f"Exported {table} to {path}/{table}") | ||
|
||
return result | ||
|
||
|
||
def run_query(self, query: str) -> str: | ||
"""Function that runs a query and returns the result. | ||
:param query: SQL query to run | ||
:return: Result of the query | ||
""" | ||
|
||
# -*- Format the SQL Query | ||
# Remove backticks | ||
formatted_sql = query.replace("`", "") | ||
# If there are multiple statements, only run the first one | ||
formatted_sql = formatted_sql.split(";")[0] | ||
|
||
try: | ||
logger.info(f"Running: {formatted_sql}") | ||
|
||
cursor = self.connection.cursor() | ||
cursor.execute(query) | ||
query_result = cursor.fetchall() | ||
|
||
result_output = "No output" | ||
if query_result is not None: | ||
try: | ||
results_as_python_objects = query_result | ||
result_rows = [] | ||
for row in results_as_python_objects: | ||
if len(row) == 1: | ||
result_rows.append(str(row[0])) | ||
else: | ||
result_rows.append(",".join(str(x) for x in row)) | ||
|
||
result_data = "\n".join(result_rows) | ||
result_output = ",".join(query_result.columns) + "\n" + result_data | ||
except AttributeError: | ||
result_output = str(query_result) | ||
|
||
logger.debug(f"Query result: {result_output}") | ||
|
||
return result_output | ||
except Exception as e: | ||
return str(e) |