diff --git a/klondike/snowflake/snowflake.py b/klondike/snowflake/snowflake.py index 8c52f2c..2e6eec3 100644 --- a/klondike/snowflake/snowflake.py +++ b/klondike/snowflake/snowflake.py @@ -34,22 +34,64 @@ class SnowflakeConnector: def __init__( self, - snowflake_user: str = os.getenv("SNOWFLAKE_USER"), - snowflake_password: str = os.getenv("SNOWFLAKE_PASSWORD"), - snowflake_account: str = os.getenv("SNOWFLAKE_ACCOUNT"), - snowflake_database: str = os.getenv("SNOWFLAKE_DATABASE"), - snowflake_warehouse: str = os.getenv("SNOWFLAKE_WAREHOUSE"), + snowflake_user: Optional[str] = None, + snowflake_password: Optional[str] = None, + snowflake_account: Optional[str] = None, + snowflake_database: Optional[str] = None, + snowflake_warehouse: Optional[str] = None, row_chunk_size: int = 100_000, ): - self.snowflake_user = snowflake_user - self.snowflake_password = snowflake_password - self.snowflake_account = snowflake_account - self.__snowflake_warehouse = snowflake_warehouse - self.__snowflake_database = snowflake_database + """ + All Snowflake connection values either need to be supplied as constructor + arguments or inferred from the environment; if neither occurs, a `ValueError` + will be raised + """ + + self.snowflake_user = ( + snowflake_user if snowflake_user else os.getenv("SNOWFLAKE_USER") + ) + self.snowflake_password = ( + snowflake_password + if snowflake_password + else os.getenv("SNOWFLAKE_PASSWORD") + ) + self.snowflake_account = ( + snowflake_account if snowflake_account else os.getenv("SNOWFLAKE_ACCOUNT") + ) + self.__snowflake_warehouse = ( + snowflake_warehouse + if snowflake_warehouse + else os.getenv("SNOWFLAKE_WAREHOUSE") + ) + self.__snowflake_database = ( + snowflake_database + if snowflake_database + else os.getenv("SNOWFLAKE_DATABASE") + ) + + ### + + self.__validate_authentication() + + ### self.dialect = "snowflake" self.__row_chunk_size = row_chunk_size + def __validate_authentication(self): + _auth_vals = [ + self.snowflake_user, + self.snowflake_password, + self.snowflake_account, + self.snowflake_database, + self.snowflake_warehouse, + ] + + if any([not x for x in _auth_vals]): + raise ValueError( + "Missing authentication values! Make sure all `snowflake_*` values are provided at construction" + ) + @property def snowflake_warehouse(self): return self.__snowflake_warehouse