Skip to content

Commit

Permalink
Refactor 'identifiers_as_lower' to 'columns_names_capitalization' (#567)
Browse files Browse the repository at this point in the history
Closes: #564

Move towards a consistent strategy for changing column names' capitalisation, both for dataframes and for files loaded using `load_file`.
* Expose 'columns_names_capitalization' both in `aql.load_file` and `aql.dataframe`
* Remove 'identifiers_as_lower' from `aql.dataframe`
  • Loading branch information
utkarsharma2 authored Jul 26, 2022
1 parent 5ec3057 commit 0b9afa3
Show file tree
Hide file tree
Showing 16 changed files with 233 additions and 76 deletions.
2 changes: 1 addition & 1 deletion example_dags/example_amazon_s3_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def sample_create_table(input_table: Table):
return "SELECT * FROM {{input_table}} LIMIT 10"


@aql.dataframe(identifiers_as_lower=False)
@aql.dataframe(columns_names_capitalization="original")
def my_df_func(input_df: DataFrame):
print(input_df)

Expand Down
2 changes: 1 addition & 1 deletion example_dags/example_amazon_s3_snowflake_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def clean_data(input_table: Table):
"""


@aql.dataframe()
@aql.dataframe(columns_names_capitalization="original")
def aggregate_data(df: pd.DataFrame):
new_df = df.pivot_table(
index="date", values="name", columns=["type"], aggfunc="count"
Expand Down
2 changes: 1 addition & 1 deletion example_dags/example_google_bigquery_gcs_load_and_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
)

# Setting "columns_names_capitalization" to "lower" will lowercase all column names
@aql.dataframe(identifiers_as_lower=False)
@aql.dataframe(columns_names_capitalization="original")
def extract_top_5_movies(input_df: pd.DataFrame):
print(f"Total Number of records: {len(input_df)}")
top_5_movies = input_df.sort_values(by="Rating", ascending=False)[
Expand Down
2 changes: 2 additions & 0 deletions src/astro/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,5 @@ class Database(Enum):

# TODO: check how snowflake names these
MergeConflictStrategy = Literal["ignore", "update", "exception"]

ColumnCapitalization = Literal["upper", "lower", "original"]
22 changes: 19 additions & 3 deletions src/astro/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from astro.constants import (
DEFAULT_CHUNK_SIZE,
ColumnCapitalization,
ExportExistsStrategy,
LoadExistStrategy,
MergeConflictStrategy,
Expand Down Expand Up @@ -183,13 +184,16 @@ def create_table_using_schema_autodetection(
table: Table,
file: Optional[File] = None,
dataframe: Optional[pd.DataFrame] = None,
columns_names_capitalization: ColumnCapitalization = "lower", # skipcq
) -> None:
"""
Create a SQL table, automatically inferring the schema using the given file.
:param table: The table to be created.
:param file: File used to infer the new table columns.
:param dataframe: Dataframe used to infer the new table columns if there is no file
:param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase
in the resulting dataframe
"""
if file is None:
if dataframe is None:
Expand All @@ -201,6 +205,7 @@ def create_table_using_schema_autodetection(
source_dataframe = file.export_to_dataframe(
nrows=LOAD_TABLE_AUTODETECT_ROWS_COUNT
)

db = SQLDatabase(engine=self.sqlalchemy_engine)
db.prep_table(
source_dataframe,
Expand All @@ -215,6 +220,7 @@ def create_table(
table: Table,
file: Optional[File] = None,
dataframe: Optional[pd.DataFrame] = None,
columns_names_capitalization: ColumnCapitalization = "original",
) -> None:
"""
Create a table either using its explicitly defined columns or inferring
Expand All @@ -223,12 +229,15 @@ def create_table(
:param table: The table to be created
:param file: (optional) File used to infer the table columns.
:param dataframe: (optional) Dataframe used to infer the new table columns if there is no file
:param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase
in the resulting dataframe
"""
if table.columns:
self.create_table_using_columns(table)
else:
self.create_table_using_schema_autodetection(table, file, dataframe)
self.create_table_using_schema_autodetection(
table, file, dataframe, columns_names_capitalization
)

def create_table_from_select_statement(
self,
Expand Down Expand Up @@ -272,6 +281,7 @@ def load_file_to_table(
chunk_size: int = DEFAULT_CHUNK_SIZE,
use_native_support: bool = True,
native_support_kwargs: Optional[Dict] = None,
columns_names_capitalization: ColumnCapitalization = "original",
enable_native_fallback: Optional[bool] = True,
**kwargs,
):
Expand All @@ -286,6 +296,8 @@ def load_file_to_table(
:param use_native_support: Use native support for data transfer if available on the destination
:param normalize_config: pandas json_normalize params config
:param native_support_kwargs: kwargs to be used by method involved in native support flow
:param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase
in the resulting dataframe
:param enable_native_fallback: Use enable_native_fallback=True to fall back to default transfer
"""
normalize_config = normalize_config or {}
Expand All @@ -297,7 +309,11 @@ def load_file_to_table(
self.create_schema_if_needed(output_table.metadata.schema)
if if_exists == "replace" or not self.table_exists(output_table):
self.drop_table(output_table)
self.create_table(output_table, input_files[0])
self.create_table(
output_table,
input_files[0],
columns_names_capitalization=columns_names_capitalization,
)
if_exists = "append"

# TODO: many native transfers support the input_file.path - it may be better
Expand Down
14 changes: 10 additions & 4 deletions src/astro/databases/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from astro import settings
from astro.constants import (
DEFAULT_CHUNK_SIZE,
ColumnCapitalization,
FileLocation,
FileType,
LoadExistStrategy,
Expand Down Expand Up @@ -333,6 +334,7 @@ def create_table_using_schema_autodetection(
table: Table,
file: Optional[File] = None,
dataframe: Optional[pd.DataFrame] = None,
columns_names_capitalization: ColumnCapitalization = "lower",
) -> None:
"""
Create a SQL table, automatically inferring the schema using the given file.
Expand All @@ -341,16 +343,20 @@ def create_table_using_schema_autodetection(
:param file: File used to infer the new table columns.
:param dataframe: Dataframe used to infer the new table columns if there is no file
"""

    # Snowflake doesn't expect mixed-case column names like 'Title' or 'Category',
    # so we explicitly convert them to lower case when the user has not specified "lower" or "upper"
if columns_names_capitalization not in ["lower", "upper"]:

This comment has been minimized.

Copy link
@tatiana

tatiana Jul 27, 2022

Collaborator

Perhaps we should raise an exception if the user manually sets columns_names_capitalization=original here, WDYT?

columns_names_capitalization = "lower"

if file:
dataframe = file.export_to_dataframe(
nrows=settings.LOAD_TABLE_AUTODETECT_ROWS_COUNT
nrows=settings.LOAD_TABLE_AUTODETECT_ROWS_COUNT,
columns_names_capitalization=columns_names_capitalization,
)

# Snowflake doesn't handle well mixed capitalisation of column name chars
# we are handling this more gracefully in a separate PR
if dataframe is not None:
dataframe.columns.str.upper()

super().create_table_using_schema_autodetection(table, dataframe=dataframe)

def is_native_load_file_available(
Expand Down
13 changes: 11 additions & 2 deletions src/astro/files/types/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,26 @@

from astro.constants import FileType as FileTypeConstants
from astro.files.types.base import FileType
from astro.utils.dataframe import convert_columns_names_capitalization


class CSVFileType(FileType):
"""Concrete implementation to handle CSV file type"""

def export_to_dataframe(self, stream, **kwargs) -> pd.DataFrame:
def export_to_dataframe(
self, stream, columns_names_capitalization="original", **kwargs
) -> pd.DataFrame:
"""read csv file from one of the supported locations and return dataframe
:param stream: file stream object
:param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase
in the resulting dataframe
"""
return pd.read_csv(stream, **kwargs)
df = pd.read_csv(stream, **kwargs)
df = convert_columns_names_capitalization(
df=df, columns_names_capitalization=columns_names_capitalization
)
return df

def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None:
"""Write csv file to one of the supported locations
Expand Down
16 changes: 14 additions & 2 deletions src/astro/files/types/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,32 @@

from astro.constants import FileType as FileTypeConstants
from astro.files.types.base import FileType
from astro.utils.dataframe import convert_columns_names_capitalization


class JSONFileType(FileType):
"""Concrete implementation to handle JSON file type"""

def export_to_dataframe(self, stream: io.TextIOWrapper, **kwargs):
def export_to_dataframe(
self,
stream: io.TextIOWrapper,
columns_names_capitalization="original",
**kwargs
) -> pd.DataFrame:
"""read json file from one of the supported locations and return dataframe
:param stream: file stream object
:param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase
in the resulting dataframe
"""
kwargs_copy = dict(kwargs)
# Pandas `read_json` does not support the `nrows` parameter unless we're using NDJSON
kwargs_copy.pop("nrows", None)
return pd.read_json(stream, **kwargs_copy)
df = pd.read_json(stream, **kwargs_copy)
df = convert_columns_names_capitalization(
df=df, columns_names_capitalization=columns_names_capitalization
)
return df

def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None:
"""Write json file to one of the supported locations
Expand Down
13 changes: 11 additions & 2 deletions src/astro/files/types/ndjson.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,26 @@
from astro.constants import DEFAULT_CHUNK_SIZE
from astro.constants import FileType as FileTypeConstants
from astro.files.types.base import FileType
from astro.utils.dataframe import convert_columns_names_capitalization


class NDJSONFileType(FileType):
"""Concrete implementation to handle NDJSON file type"""

def export_to_dataframe(self, stream, **kwargs):
def export_to_dataframe(
self, stream, columns_names_capitalization="original", **kwargs
):
"""read ndjson file from one of the supported locations and return dataframe
:param stream: file stream object
:param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase
in the resulting dataframe
"""
return NDJSONFileType.flatten(self.normalize_config, stream, **kwargs)
df = NDJSONFileType.flatten(self.normalize_config, stream, **kwargs)
df = convert_columns_names_capitalization(
df=df, columns_names_capitalization=columns_names_capitalization
)
return df

def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None:
"""Write ndjson file to one of the supported locations
Expand Down
14 changes: 12 additions & 2 deletions src/astro/files/types/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,30 @@

from astro.constants import FileType as FileTypeConstants
from astro.files.types.base import FileType
from astro.utils.dataframe import convert_columns_names_capitalization


class ParquetFileType(FileType):
"""Concrete implementation to handle Parquet file type"""

def export_to_dataframe(self, stream, **kwargs):
def export_to_dataframe(
self, stream, columns_names_capitalization="original", **kwargs
):
"""read parquet file from one of the supported locations and return dataframe
:param stream: file stream object
:param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase
in the resulting dataframe
"""
kwargs_copy = dict(kwargs)
# Pandas `read_parquet` does not support the `nrows` parameter
kwargs_copy.pop("nrows", None)
return pd.read_parquet(stream, **kwargs_copy)

df = pd.read_parquet(stream, **kwargs_copy)
df = convert_columns_names_capitalization(
df=df, columns_names_capitalization=columns_names_capitalization
)
return df

def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None:
"""Write parquet file to one of the supported locations
Expand Down
6 changes: 3 additions & 3 deletions src/astro/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from airflow.decorators.base import task_decorator_factory
from airflow.decorators import _TaskDecorator as TaskDecorator

from astro.constants import MergeConflictStrategy
from astro.constants import ColumnCapitalization, MergeConflictStrategy
from astro.sql.operators.append import APPEND_COLUMN_TYPE, AppendOperator
from astro.sql.operators.cleanup import CleanupOperator
from astro.sql.operators.dataframe import DataframeOperator
Expand Down Expand Up @@ -245,7 +245,7 @@ def dataframe(
database: Optional[str] = None,
schema: Optional[str] = None,
task_id: Optional[str] = None,
identifiers_as_lower: Optional[bool] = True,
columns_names_capitalization: ColumnCapitalization = "lower",
) -> Callable[..., pd.DataFrame]:
"""
This decorator will allow users to write python functions while treating SQL tables as dataframes
Expand All @@ -257,7 +257,7 @@ def dataframe(
"conn_id": conn_id,
"database": database,
"schema": schema,
"identifiers_as_lower": identifiers_as_lower,
"columns_names_capitalization": columns_names_capitalization,
}
if task_id:
param_map["task_id"] = task_id
Expand Down
Loading

0 comments on commit 0b9afa3

Please sign in to comment.