Remove TempPostgresHook (#95)
* Remove TempPostgresHook

* more removals
dimberman authored Feb 9, 2022
1 parent 6b92b49 commit 1cfd3bc
Showing 5 changed files with 16 additions and 44 deletions.
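
In short, every call site that previously built the project-local `TempPostgresHook` wrapper now constructs the stock provider hook directly. A minimal sketch of the call-site change, assuming the `postgres_conn` connection id and `pagila` schema used in the tests below:

```python
# Before this commit: project-local wrapper (now removed)
# from astro.sql.operators.temp_hooks import TempPostgresHook
# hook = TempPostgresHook(postgres_conn_id="postgres_conn", schema="pagila")

# After this commit: the stock Airflow provider hook is used directly
from airflow.providers.postgres.hooks.postgres import PostgresHook

hook = PostgresHook(postgres_conn_id="postgres_conn", schema="pagila")
uri = hook.get_uri()  # the provider hook's own get_uri() is relied on from now on
```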
5 changes: 3 additions & 2 deletions src/astro/sql/operators/agnostic_save_file.py
@@ -23,10 +23,11 @@
from airflow.hooks.base import BaseHook
from airflow.models import BaseOperator, DagRun, TaskInstance
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.postgres.hooks.postgres import PostgresHook
from google.cloud.storage import Client
from smart_open import open

from astro.sql.operators.temp_hooks import TempPostgresHook, TempSnowflakeHook
from astro.sql.operators.temp_hooks import TempSnowflakeHook
from astro.sql.table import Table
from astro.utils.cloud_storage_creds import gcs_client, s3fs_creds
from astro.utils.schema_util import get_table_name
@@ -81,7 +82,7 @@ def execute(self, context):

# Select database Hook based on `conn` type
input_hook = {
"postgres": TempPostgresHook(
"postgres": PostgresHook(
postgres_conn_id=input_table.conn_id, schema=input_table.database
),
"snowflake": TempSnowflakeHook(
29 changes: 0 additions & 29 deletions src/astro/sql/operators/temp_hooks.py
@@ -13,13 +13,8 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
import warnings
from typing import Dict, Optional, Sequence, Union
from urllib.parse import quote_plus

from airflow.hooks.dbapi import DbApiHook
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
from airflow.providers.postgres.hooks.postgres import PostgresHook
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook

@@ -37,27 +32,3 @@ def get_uri(self) -> str:
"?warehouse={warehouse}&role={role}&authenticator={authenticator}"
)
return uri.format(**conn_config)


class TempPostgresHook(PostgresHook):
"""
Temporary class to get around a bug in the snowflakehook when creating URIs
"""

def get_uri(self) -> str:
"""
Extract the URI from the connection.
:return: the extracted uri.
"""
conn = self.get_connection(getattr(self, self.conn_name_attr))
login = ""
if conn.login:
login = f"{quote_plus(conn.login)}:{quote_plus(conn.password)}@"
host = conn.host
if conn.port is not None:
host += f":{conn.port}"
uri = f"postgresql://{login}{host}/"
if self.schema:
uri += self.schema
return uri
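
For context, the override removed above existed only to URL-escape the login and password with `quote_plus` before assembling a `postgresql://` URI. A standalone sketch of that behaviour, using a hypothetical helper rather than the hook itself:

```python
from urllib.parse import quote_plus


def build_postgres_uri(login, password, host, port=None, schema=None):
    # Escape credentials so characters such as '@' or '/' in a password
    # cannot break the URI -- the job the removed get_uri override did.
    auth = f"{quote_plus(login)}:{quote_plus(password)}@" if login else ""
    hostport = f"{host}:{port}" if port is not None else host
    return f"postgresql://{auth}{hostport}/{schema or ''}"


print(build_postgres_uri("user", "p@ss/word", "localhost", 5432, "pagila"))
# postgresql://user:p%40ss%2Fword@localhost:5432/pagila
```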
9 changes: 5 additions & 4 deletions src/astro/utils/load_dataframe.py
@@ -16,11 +16,12 @@
from typing import Optional, Union

from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.postgres.hooks.postgres import PostgresHook
from pandas import DataFrame
from pandas.io.sql import SQLDatabase
from snowflake.connector.pandas_tools import write_pandas

from astro.sql.operators.temp_hooks import TempPostgresHook, TempSnowflakeHook
from astro.sql.operators.temp_hooks import TempSnowflakeHook
from astro.utils.schema_util import set_schema_query


@@ -36,9 +37,9 @@ def move_dataframe_to_sql(
chunksize=None,
):
# Select database Hook based on `conn` type
hook: Union[TempPostgresHook, TempSnowflakeHook] = { # type: ignore
"postgresql": TempPostgresHook(postgres_conn_id=conn_id, schema=database),
"postgres": TempPostgresHook(postgres_conn_id=conn_id, schema=database),
hook: Union[PostgresHook, TempSnowflakeHook] = { # type: ignore
"postgresql": PostgresHook(postgres_conn_id=conn_id, schema=database),
"postgres": PostgresHook(postgres_conn_id=conn_id, schema=database),
"snowflake": TempSnowflakeHook(
snowflake_conn_id=conn_id,
database=database,
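
The hunk above is cut off before the dictionary is indexed; a self-contained sketch of the same dispatch pattern, with a hypothetical `conn_type` parameter standing in for the connection type that `move_dataframe_to_sql` resolves elsewhere, and the stock `SnowflakeHook` substituted for `TempSnowflakeHook` so the example runs on its own:

```python
from typing import Union

from airflow.providers.postgres.hooks.postgres import PostgresHook
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook


def pick_hook(conn_type: str, conn_id: str, database: str) -> Union[PostgresHook, SnowflakeHook]:
    # Build one hook per supported backend, then select by connection type.
    hooks = {
        "postgresql": PostgresHook(postgres_conn_id=conn_id, schema=database),
        "postgres": PostgresHook(postgres_conn_id=conn_id, schema=database),
        "snowflake": SnowflakeHook(snowflake_conn_id=conn_id, database=database),
    }
    return hooks[conn_type]
```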
13 changes: 7 additions & 6 deletions tests/operators/test_agnostic_load_file.py
@@ -36,6 +36,8 @@
from airflow.models import DAG, DagRun
from airflow.models import TaskInstance as TI
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook

# from astro.sql.operators.temp_hooks import PostgresHook
from airflow.providers.postgres.hooks.postgres import PostgresHook
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
from airflow.utils import timezone
@@ -47,7 +49,6 @@

# Import Operator
from astro.sql.operators.agnostic_load_file import AgnosticLoadFile, load_file
from astro.sql.operators.temp_hooks import TempPostgresHook
from astro.sql.table import Table, TempTable
from tests.operators import utils as test_utils

@@ -234,7 +235,7 @@ def test_unique_task_id_for_same_path(self):
def test_aql_local_file_to_postgres_no_table_name(self):
OUTPUT_TABLE_NAME = "expected_table_from_csv"

self.hook_target = TempPostgresHook(
self.hook_target = PostgresHook(
postgres_conn_id="postgres_conn", schema="pagila"
)

@@ -299,7 +300,7 @@ def test_aql_local_file_to_bigquery_no_table_name(self):
def test_aql_overwrite_existing_table(self):
OUTPUT_TABLE_NAME = "expected_table_from_csv"

self.hook_target = TempPostgresHook(
self.hook_target = PostgresHook(
postgres_conn_id="postgres_conn", schema="pagila"
)

@@ -341,7 +342,7 @@ def test_aql_overwrite_existing_table(self):
def test_aql_s3_file_to_postgres(self):
OUTPUT_TABLE_NAME = "expected_table_from_s3_csv"

self.hook_target = TempPostgresHook(
self.hook_target = PostgresHook(
postgres_conn_id="postgres_conn", schema="pagila"
)

@@ -373,7 +374,7 @@ def test_aql_s3_file_to_postgres(self):
def test_aql_s3_file_to_postgres_no_table_name(self):
OUTPUT_TABLE_NAME = "test_dag_load_file_homes_csv_2"

self.hook_target = TempPostgresHook(
self.hook_target = PostgresHook(
postgres_conn_id="postgres_conn", schema="pagila"
)

@@ -407,7 +408,7 @@ def test_aql_s3_file_to_postgres_no_table_name(self):
def test_aql_s3_file_to_postgres_specify_schema(self):
OUTPUT_TABLE_NAME = "expected_table_from_s3_csv"

self.hook_target = TempPostgresHook(
self.hook_target = PostgresHook(
postgres_conn_id="postgres_conn", schema="pagila"
)

4 changes: 1 addition & 3 deletions tests/operators/test_postgres_append.py
@@ -31,7 +31,6 @@
import unittest.mock

import pandas as pd
import pytest
from airflow.executors.debug_executor import DebugExecutor
from airflow.models import DAG, DagRun
from airflow.models import TaskInstance as TI
@@ -44,7 +43,6 @@

# Import Operator
import astro.sql as aql
from astro.sql.operators.temp_hooks import TempPostgresHook
from astro.sql.table import Table

# from tests.operators import utils as test_utils
@@ -185,7 +183,7 @@ def test_append(self):

def test_append_all_fields(self):

hook = TempPostgresHook(postgres_conn_id="postgres_conn", schema="pagila")
hook = PostgresHook(postgres_conn_id="postgres_conn", schema="pagila")

drop_table(table_name="test_main", postgres_conn=hook.get_conn())
drop_table(table_name="test_append", postgres_conn=hook.get_conn())