Skip to content

Commit

Permalink
clean up constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
IanRFerguson committed Jun 7, 2024
1 parent c5bae0f commit 2325733
Showing 1 changed file with 52 additions and 10 deletions.
62 changes: 52 additions & 10 deletions klondike/snowflake/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,64 @@ class SnowflakeConnector:

def __init__(
self,
snowflake_user: str = os.getenv("SNOWFLAKE_USER"),
snowflake_password: str = os.getenv("SNOWFLAKE_PASSWORD"),
snowflake_account: str = os.getenv("SNOWFLAKE_ACCOUNT"),
snowflake_database: str = os.getenv("SNOWFLAKE_DATABASE"),
snowflake_warehouse: str = os.getenv("SNOWFLAKE_WAREHOUSE"),
snowflake_user: Optional[str] = None,
snowflake_password: Optional[str] = None,
snowflake_account: Optional[str] = None,
snowflake_database: Optional[str] = None,
snowflake_warehouse: Optional[str] = None,
row_chunk_size: int = 100_000,
):
self.snowflake_user = snowflake_user
self.snowflake_password = snowflake_password
self.snowflake_account = snowflake_account
self.__snowflake_warehouse = snowflake_warehouse
self.__snowflake_database = snowflake_database
"""
All Snowflake connection values either need to be supplied as constructor
arguments or inferred from the environment; if neither occurs, a `ValueError`
will be raised
"""

self.snowflake_user = (
snowflake_user if snowflake_user else os.getenv("SNOWFLAKE_USER")
)
self.snowflake_password = (
snowflake_password
if snowflake_password
else os.getenv("SNOWFLAKE_PASSWORD")
)
self.snowflake_account = (
snowflake_account if snowflake_account else os.getenv("SNOWFLAKE_ACCOUNT")
)
self.__snowflake_warehouse = (
snowflake_warehouse
if snowflake_warehouse
else os.getenv("SNOWFLAKE_WAREHOUSE")
)
self.__snowflake_database = (
snowflake_database
if snowflake_database
else os.getenv("SNOWFLAKE_DATABASE")
)

###

self.__validate_authentication()

###

self.dialect = "snowflake"
self.__row_chunk_size = row_chunk_size

def __validate_authentication(self):
_auth_vals = [
self.snowflake_user,
self.snowflake_password,
self.snowflake_account,
self.snowflake_database,
self.snowflake_warehouse,
]

if any([not x for x in _auth_vals]):
raise ValueError(
"Missing authentication values! Make sure all `snowflake_*` values are provided at construction"
)

@property
def snowflake_warehouse(self):
return self.__snowflake_warehouse
Expand Down

0 comments on commit 2325733

Please sign in to comment.