Bi 5098: filter partitions and system tables for PG and GP #189

Merged · 5 commits · Dec 25, 2023
74 changes: 74 additions & 0 deletions lib/dl_connector_greenplum/dl_connector_greenplum/core/adapters.py
@@ -0,0 +1,74 @@
from typing import List

import sqlalchemy as sa

from dl_core.connection_executors.models.db_adapter_data import RawSchemaInfo
from dl_core.connection_models import TableDefinition
from dl_core.connection_models.common_models import DBIdent

from dl_connector_postgresql.core.postgresql_base.adapters_postgres import PostgresAdapter
from dl_connector_postgresql.core.postgresql_base.async_adapters_postgres import AsyncPostgresAdapter


GP_LIST_SOURCES_ALL_SCHEMAS_SQL = """
SELECT
pg_namespace.nspname as schema,
pg_class.relname as name
FROM
pg_class
JOIN pg_namespace
ON pg_namespace.oid = pg_class.relnamespace
LEFT JOIN pg_partitions
ON pg_partitions.partitiontablename = pg_class.relname
WHERE
pg_namespace.nspname not like 'pg_%'
AND pg_namespace.nspname not like 'gp_%'
AND pg_namespace.nspname != 'session_state'
AND pg_namespace.nspname != 'information_schema'
AND pg_class.relkind in ('m', 'p', 'r', 'v')
AND pg_partitions.tablename is NULL
ORDER BY schema, name;
"""


GP_LIST_SCHEMA_NAMES = """
SELECT nspname FROM pg_namespace
WHERE nspname NOT LIKE 'pg_%'
AND nspname NOT LIKE 'gp_%'
AND nspname != 'session_state'
AND nspname != 'information_schema'
ORDER BY nspname
"""


GP_LIST_TABLE_NAMES = """
SELECT c.relname
FROM
pg_class c
JOIN pg_namespace n ON n.oid = c.relnamespace
LEFT JOIN pg_partitions p ON p.partitiontablename = c.relname
WHERE
n.nspname = :schema
AND c.relkind in ('r', 'p')
AND p.tablename is NULL
"""


class GreenplumAdapter(PostgresAdapter):
    _LIST_ALL_TABLES_QUERY = GP_LIST_SOURCES_ALL_SCHEMAS_SQL
    _LIST_TABLE_NAMES_QUERY = GP_LIST_TABLE_NAMES
    _LIST_SCHEMA_NAMES_QUERY = GP_LIST_SCHEMA_NAMES

    def _get_schema_names(self, db_ident: DBIdent) -> List[str]:
        db_engine = self.get_db_engine(db_ident.db_name)
        schema_names = [schema_name for schema_name, in db_engine.execute(sa.text(self._LIST_SCHEMA_NAMES_QUERY))]
        return schema_names


class AsyncGreenplumAdapter(AsyncPostgresAdapter):
    _LIST_ALL_TABLES_QUERY = GP_LIST_SOURCES_ALL_SCHEMAS_SQL
    _LIST_SCHEMA_NAMES_QUERY = GP_LIST_SCHEMA_NAMES
    _LIST_TABLE_NAMES_QUERY = GP_LIST_TABLE_NAMES

    async def get_table_info(self, table_def: TableDefinition, fetch_idx_info: bool) -> RawSchemaInfo:
        raise NotImplementedError()
@@ -0,0 +1,16 @@
from dl_connector_greenplum.core.adapters import (
    AsyncGreenplumAdapter,
    GreenplumAdapter,
)
from dl_connector_postgresql.core.postgresql_base.connection_executors import (
    AsyncPostgresConnExecutor,
    PostgresConnExecutor,
)


class GreenplumConnExecutor(PostgresConnExecutor):
    TARGET_ADAPTER_CLS = GreenplumAdapter


class AsyncGreenplumConnExecutor(AsyncPostgresConnExecutor):
    TARGET_ADAPTER_CLS = AsyncGreenplumAdapter
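
Both executors are thin wrappers: the whole execution stack is inherited from the PostgreSQL executors, and only `TARGET_ADAPTER_CLS` is rebound so that the Greenplum catalog queries above get used.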
@@ -8,6 +8,14 @@
SQLTableCoreSourceDefinitionBase,
)

+from dl_connector_greenplum.core.adapters import (
+    AsyncGreenplumAdapter,
+    GreenplumAdapter,
+)
+from dl_connector_greenplum.core.connection_executors import (
+    AsyncGreenplumConnExecutor,
+    GreenplumConnExecutor,
+)
from dl_connector_greenplum.core.constants import (
    BACKEND_TYPE_GREENPLUM,
    CONNECTION_TYPE_GREENPLUM,
@@ -21,12 +29,6 @@
from dl_connector_greenplum.core.data_source_migration import GreenPlumDataSourceMigrator
from dl_connector_greenplum.core.storage_schemas.connection import GreenplumConnectionDataStorageSchema
from dl_connector_greenplum.core.us_connection import GreenplumConnection
-from dl_connector_postgresql.core.postgresql_base.adapters_postgres import PostgresAdapter
-from dl_connector_postgresql.core.postgresql_base.async_adapters_postgres import AsyncPostgresAdapter
-from dl_connector_postgresql.core.postgresql_base.connection_executors import (
-    AsyncPostgresConnExecutor,
-    PostgresConnExecutor,
-)
from dl_connector_postgresql.core.postgresql_base.query_compiler import PostgreSQLQueryCompiler
from dl_connector_postgresql.core.postgresql_base.sa_types import SQLALCHEMY_POSTGRES_TYPES
from dl_connector_postgresql.core.postgresql_base.type_transformer import PostgreSQLTypeTransformer
@@ -37,8 +39,8 @@ class GreenplumCoreConnectionDefinition(CoreConnectionDefinition):
    connection_cls = GreenplumConnection
    us_storage_schema_cls = GreenplumConnectionDataStorageSchema
    type_transformer_cls = PostgreSQLTypeTransformer
-    sync_conn_executor_cls = PostgresConnExecutor
-    async_conn_executor_cls = AsyncPostgresConnExecutor
+    sync_conn_executor_cls = GreenplumConnExecutor
+    async_conn_executor_cls = AsyncGreenplumConnExecutor
    dialect_string = "bi_postgresql"
    data_source_migrator_cls = GreenPlumDataSourceMigrator

@@ -65,5 +67,5 @@ class GreenplumCoreConnector(CoreConnector):
        GreenplumTableCoreSourceDefinition,
        GreenplumSubselectCoreSourceDefinition,
    )
-    rqe_adapter_classes = frozenset({PostgresAdapter, AsyncPostgresAdapter})
+    rqe_adapter_classes = frozenset({GreenplumAdapter, AsyncGreenplumAdapter})
    sa_types = SQLALCHEMY_POSTGRES_TYPES
@@ -360,11 +360,18 @@
pg_namespace.nspname not like 'pg_%'
AND pg_namespace.nspname != 'information_schema'
AND pg_class.relkind in ('m', 'p', 'r', 'v')
+AND NOT COALESCE((row_to_json(pg_class)->>'relispartition')::boolean, false)
ORDER BY schema, name
""".strip().replace(
    "\n", " "
)
-# NOTE: there's also `AND NOT pg_class.relispartition` in postgresql>=10
+# NOTE: the `pg_class.relispartition` field exists only in postgresql>=10, so for postgresql<10
+# support the json-based check is used here
+PG_LIST_TABLE_NAMES = """
+SELECT c.relname FROM pg_class c
+JOIN pg_namespace n ON n.oid = c.relnamespace
+WHERE n.nspname = :schema AND c.relkind in ('r', 'p')
+AND NOT COALESCE((row_to_json(c)->>'relispartition')::boolean, false);
+"""


@attr.s(cmp=False)
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextlib
+import itertools
import typing
from typing import (
    TYPE_CHECKING,
@@ -21,6 +22,7 @@
from dl_connector_postgresql.core.postgresql_base.adapters_base_postgres import (
    OID_KNOWLEDGE,
    PG_LIST_SOURCES_ALL_SCHEMAS_SQL,
+    PG_LIST_TABLE_NAMES,
    BasePostgresAdapter,
)
from dl_connector_postgresql.core.postgresql_base.error_transformer import sync_pg_db_error_transformer
@@ -40,6 +42,9 @@ class PostgresAdapter(BasePostgresAdapter, BaseClassicAdapter[PostgresConnTarget
"stream_results": True,
}

+    _LIST_ALL_TABLES_QUERY = PG_LIST_SOURCES_ALL_SCHEMAS_SQL
+    _LIST_TABLE_NAMES_QUERY = PG_LIST_TABLE_NAMES

    def get_connect_args(self) -> dict:
        return dict(
            super().get_connect_args(),
@@ -70,16 +75,22 @@ def execution_context(self) -> typing.Generator[None, None, None]:
            stack.close()

    def _get_tables(self, schema_ident: SchemaIdent) -> List[TableIdent]:
+        db_name = schema_ident.db_name
+        db_engine = self.get_db_engine(db_name)
+
        if schema_ident.schema_name is not None:
-            # For a single schema, plug into the common SA code.
-            # (might not be ever used)
-            return super()._get_tables(schema_ident)
-
-        assert schema_ident.schema_name is None
-        db_name = schema_ident.db_name
-        db_engine = self.get_db_engine(db_name)
-        query = PG_LIST_SOURCES_ALL_SCHEMAS_SQL
-        result = db_engine.execute(sa.text(query))
+            db_engine = self.get_db_engine(schema_ident.db_name)
+            table_list = db_engine.execute(sa.text(self._LIST_TABLE_NAMES_QUERY))
+            view_list = sa.inspect(db_engine).get_view_names(schema=schema_ident.schema_name)
+
+            result = ((schema_ident.schema_name, name) for name in itertools.chain(table_list, view_list))
+        else:
+            assert schema_ident.schema_name is None
+            result = db_engine.execute(sa.text(self._LIST_ALL_TABLES_QUERY))

        return [
            TableIdent(
                db_name=db_name,
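
Both branches feed the same `TableIdent` construction; which one runs depends only on whether `SchemaIdent.schema_name` is set. A usage sketch (`adapter` stands in for a constructed `PostgresAdapter`, which this diff does not show):

from dl_core.connection_models.common_models import SchemaIdent

# All schemas in the database -> _LIST_ALL_TABLES_QUERY.
all_tables = adapter._get_tables(SchemaIdent(db_name="test_data", schema_name=None))

# A single schema -> _LIST_TABLE_NAMES_QUERY plus SQLAlchemy view inspection.
one_schema = adapter._get_tables(SchemaIdent(db_name="test_data", schema_name="test_data_partitions"))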
@@ -63,6 +63,7 @@
from dl_connector_postgresql.core.postgresql_base.adapters_base_postgres import (
    OID_KNOWLEDGE,
    PG_LIST_SOURCES_ALL_SCHEMAS_SQL,
+    PG_LIST_TABLE_NAMES,
    BasePostgresAdapter,
)
from dl_connector_postgresql.core.postgresql_base.error_transformer import make_async_pg_error_transformer
@@ -84,17 +85,11 @@
ORDER BY nspname
"""

-PG_LIST_TABLE_NAMES = """
-SELECT c.relname FROM pg_class c
-JOIN pg_namespace n ON n.oid = c.relnamespace
-WHERE n.nspname = :schema AND c.relkind in ('r', 'p')
-"""

# https://github.com/sqlalchemy/sqlalchemy/blob/rel_1_4/lib/sqlalchemy/dialects/postgresql/base.py#L3802
PG_LIST_VIEW_NAMES = """
SELECT c.relname FROM pg_class c
JOIN pg_namespace n ON n.oid = c.relnamespace
-WHERE n.nspname = :schema AND c.relkind IN 'v', 'm'
+WHERE n.nspname = :schema AND c.relkind IN ('v', 'm')
"""


@@ -132,6 +127,11 @@ class AsyncPostgresAdapter(
        OSError,
    )

+    _LIST_ALL_TABLES_QUERY = PG_LIST_SOURCES_ALL_SCHEMAS_SQL
+    _LIST_SCHEMA_NAMES_QUERY = PG_LIST_SCHEMA_NAMES
+    _LIST_TABLE_NAMES_QUERY = PG_LIST_TABLE_NAMES
+    _LIST_VIEW_NAMES_QUERY = PG_LIST_VIEW_NAMES
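
Hoisting the catalog SQL into `_LIST_*_QUERY` class attributes mirrors the sync adapter and is what lets `AsyncGreenplumAdapter` substitute the Greenplum variants by overriding attributes rather than methods.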

    @property
    def _dialect(self) -> AsyncBIPGDialect:
        if self.__dialect is not None:
@@ -314,54 +314,39 @@ async def _process_chunk(steps: AsyncIterator[ExecutionStep]) -> TBIChunksGen:
)

    async def get_schema_names(self, db_ident: DBIdent) -> list[str]:
-        result = await self.execute(DBAdapterQuery(PG_LIST_SCHEMA_NAMES))
+        result = await self.execute(DBAdapterQuery(self._LIST_SCHEMA_NAMES_QUERY))
        schema_names = []
        async for row in result.get_all_rows():
            for value in row:
                schema_names.append(str(value))
        return schema_names

-    async def _get_view_names(self, schema_ident: SchemaIdent) -> list[TableIdent]:
-        query = sa.text(PG_LIST_VIEW_NAMES).bindparams(
+    async def _get_relation_names(self, schema_ident: SchemaIdent, get_query: str) -> list[TableIdent]:
+        query = sa.text(get_query).bindparams(
            sa.bindparam(
                "schema",
                schema_ident.schema_name,
-                type=sa.types.Unicode,
+                type_=sa.types.Unicode,
            )
        )
        result = await self.execute(DBAdapterQuery(query))
-        views = []
+        relations = []
        async for row in result.get_all_rows():
-            views.append(str(row[0]))
+            relations.append(str(row[0]))
        return [
            TableIdent(
                db_name=schema_ident.db_name,
                schema_name=schema_ident.schema_name,
-                table_name=view,
+                table_name=rel,
            )
-            for view in views
+            for rel in relations
        ]

+    async def _get_view_names(self, schema_ident: SchemaIdent) -> list[TableIdent]:
+        return await self._get_relation_names(schema_ident, self._LIST_VIEW_NAMES_QUERY)

    async def _get_table_names(self, schema_ident: SchemaIdent) -> list[TableIdent]:
-        query = sa.text(PG_LIST_TABLE_NAMES).bindparams(
-            sa.bindparam(
-                "schema",
-                schema_ident.schema_name,
-                type=sa.types.Unicode,
-            )
-        )
-        result = await self.execute(DBAdapterQuery(query))
-        tables = []
-        async for row in result.get_all_rows():
-            tables.append(str(row[0]))
-        return [
-            TableIdent(
-                db_name=schema_ident.db_name,
-                schema_name=schema_ident.schema_name,
-                table_name=table,
-            )
-            for table in tables
-        ]
+        return await self._get_relation_names(schema_ident, self._LIST_TABLE_NAMES_QUERY)
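
Besides deduplicating `_get_view_names`/`_get_table_names`, the helper fixes the bind-parameter keyword: SQLAlchemy spells it `type_` (trailing underscore, since `type` is a builtin) and `type=` is not part of `sa.bindparam`'s signature. A minimal demonstration:

import sqlalchemy as sa

# `type_` is the keyword sa.bindparam expects; depending on the SQLAlchemy
# version, `type=` is rejected outright or silently does nothing useful.
param = sa.bindparam("schema", "public", type_=sa.types.Unicode)
query = sa.text("SELECT 1 FROM pg_namespace WHERE nspname = :schema").bindparams(param)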

    async def _get_tables_single_schema(self, schema_ident: SchemaIdent) -> list[TableIdent]:
        table_list = await self._get_table_names(schema_ident)
@@ -376,7 +361,7 @@ async def get_tables(self, schema_ident: SchemaIdent) -> list[TableIdent]:

assert schema_ident.schema_name is None
db_name = schema_ident.db_name
result = await self.execute(DBAdapterQuery(PG_LIST_SOURCES_ALL_SCHEMAS_SQL))
result = await self.execute(DBAdapterQuery(self._LIST_ALL_TABLES_QUERY))
return [
TableIdent(
db_name=db_name,
@@ -1,3 +1,7 @@
+import pytest

+from dl_api_commons.base_models import RequestContextInfo
+from dl_core.connection_models.common_models import SchemaIdent
from dl_core_testing.testcases.adapter import BaseAsyncAdapterTestClass
from dl_testing.regulated_test import RegulatedTestParams

@@ -17,3 +21,29 @@ class TestAsyncPostgreSQLAdapter(
    )

    ASYNC_ADAPTER_CLS = AsyncPostgresAdapter

+    async def test_tables_list(self, conn_bi_context: RequestContextInfo, target_conn_dto: PostgresConnTargetDTO):
+        tables = await self._make_dba(target_conn_dto, conn_bi_context).get_tables(
+            SchemaIdent(db_name="test_data", schema_name=None)
+        )
+
+        assert [f"{t.schema_name}.{t.table_name}" for t in tables] == [
+            "test_data.sample",
+            "test_data_partitions.sample_partition",
+        ]
+
+    @pytest.mark.parametrize(
+        "schema, expected_tables",
+        [
+            ("test_data", ["sample"]),
+            ("test_data_partitions", ["sample_partition"]),
+        ],
+    )
+    async def test_tables_list_schema(
+        self, conn_bi_context: RequestContextInfo, target_conn_dto: PostgresConnTargetDTO, schema, expected_tables
+    ):
+        tables = await self._make_dba(target_conn_dto, conn_bi_context).get_tables(
+            SchemaIdent(db_name="test_data", schema_name=schema)
+        )
+
+        assert [f"{t.schema_name}.{t.table_name}" for t in tables] == [f"{schema}.{t}" for t in expected_tables]
@@ -0,0 +1,17 @@
CREATE SCHEMA IF NOT EXISTS test_data_partitions;
DROP TABLE IF EXISTS test_data_partitions.sample_partition;
CREATE TABLE test_data_partitions.sample_partition (
    id int NOT NULL,
    md5 TEXT NOT NULL
) PARTITION BY RANGE (id);

CREATE TABLE test_data_partitions.sample_partition_pt_1 PARTITION OF test_data_partitions.sample_partition
    FOR VALUES FROM (0) TO (50);

CREATE TABLE test_data_partitions.sample_partition_pt_2 PARTITION OF test_data_partitions.sample_partition
    FOR VALUES FROM (50) TO (101);


INSERT INTO test_data_partitions.sample_partition (
    SELECT id, md5(random()::text) FROM generate_series(1, 100) AS id
);
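
A quick manual check of the new listing query against this fixture (sketch; placeholder DSN, PostgreSQL 10+ assumed since the fixture uses declarative partitioning):

import sqlalchemy as sa

# Placeholder DSN for a database where the fixture above has been applied.
engine = sa.create_engine("postgresql://user:pass@localhost/test_data")

PG_LIST_TABLE_NAMES = """
SELECT c.relname FROM pg_class c
JOIN pg_namespace n ON n.oid = c.relnamespace
WHERE n.nspname = :schema AND c.relkind in ('r', 'p')
AND NOT COALESCE((row_to_json(c)->>'relispartition')::boolean, false);
"""

with engine.connect() as conn:
    rows = conn.execute(sa.text(PG_LIST_TABLE_NAMES), {"schema": "test_data_partitions"})
    # Only the parent survives; sample_partition_pt_1/_pt_2 are filtered out.
    assert [r[0] for r in rows] == ["sample_partition"]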