From 509440eff7d45dc0bdda03a685a8e0aee3e450e2 Mon Sep 17 00:00:00 2001 From: Doug Beatty <44704949+dbeatty10@users.noreply.github.com> Date: Sun, 5 Feb 2023 20:36:02 -0700 Subject: [PATCH] Apply code quality formatting rules (#140) --- .bumpversion.cfg | 5 +- MANIFEST.in | 2 +- dbt/adapters/mariadb/__init__.py | 3 +- dbt/adapters/mariadb/column.py | 4 +- dbt/adapters/mariadb/connections.py | 34 +++-- dbt/adapters/mariadb/impl.py | 119 +++++++++--------- dbt/adapters/mariadb/relation.py | 2 +- dbt/adapters/mysql/__init__.py | 5 +- dbt/adapters/mysql/column.py | 4 +- dbt/adapters/mysql/connections.py | 34 +++-- dbt/adapters/mysql/impl.py | 119 +++++++++--------- dbt/adapters/mysql/relation.py | 2 +- dbt/adapters/mysql5/__init__.py | 5 +- dbt/adapters/mysql5/column.py | 4 +- dbt/adapters/mysql5/connections.py | 34 +++-- dbt/adapters/mysql5/impl.py | 119 +++++++++--------- dbt/adapters/mysql5/relation.py | 2 +- dbt/include/mariadb/__init__.py | 1 + .../macros/materializations/test/test.sql | 2 +- dbt/include/mysql/__init__.py | 1 + .../macros/materializations/test/test.sql | 2 +- dbt/include/mysql5/__init__.py | 1 + .../macros/materializations/test/test.sql | 2 +- setup.py | 20 +-- tests/conftest.py | 30 ++--- tests/functional/adapter/test_basic.py | 4 +- tests/unit/test_adapter.py | 44 ++++--- tests/unit/utils.py | 54 ++++---- 28 files changed, 331 insertions(+), 327 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 184be8b..3798f0a 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -8,13 +8,13 @@ parse = (?P\d+) ((?P[a-z]+) ?(\.)? (?P\d+))? -serialize = +serialize = {major}.{minor}.{patch}{prerelease}{num} {major}.{minor}.{patch} [bumpversion:part:prerelease] first_value = a -values = +values = a b rc @@ -33,4 +33,3 @@ replace = version = "{new_version}" [bumpversion:file:dbt/adapters/mariadb/__version__.py] search = version = "{current_version}" replace = version = "{new_version}" - diff --git a/MANIFEST.in b/MANIFEST.in index 78412d5..cfbc714 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1 @@ -recursive-include dbt/include *.sql *.yml *.md \ No newline at end of file +recursive-include dbt/include *.sql *.yml *.md diff --git a/dbt/adapters/mariadb/__init__.py b/dbt/adapters/mariadb/__init__.py index 3896fd8..820e9d1 100644 --- a/dbt/adapters/mariadb/__init__.py +++ b/dbt/adapters/mariadb/__init__.py @@ -11,4 +11,5 @@ Plugin = AdapterPlugin( adapter=MariaDBAdapter, credentials=MariaDBCredentials, - include_path=mariadb.PACKAGE_PATH) + include_path=mariadb.PACKAGE_PATH, +) diff --git a/dbt/adapters/mariadb/column.py b/dbt/adapters/mariadb/column.py index a6fe515..c7c47eb 100644 --- a/dbt/adapters/mariadb/column.py +++ b/dbt/adapters/mariadb/column.py @@ -3,7 +3,7 @@ from dbt.adapters.base.column import Column -Self = TypeVar('Self', bound='MariaDBColumn') +Self = TypeVar("Self", bound="MariaDBColumn") @dataclass @@ -18,7 +18,7 @@ class MariaDBColumn(Column): @property def quoted(self) -> str: - return '`{}`'.format(self.column) + return "`{}`".format(self.column) def __repr__(self) -> str: return "".format(self.name, self.data_type) diff --git a/dbt/adapters/mariadb/connections.py b/dbt/adapters/mariadb/connections.py index ebd749e..d85ec29 100644 --- a/dbt/adapters/mariadb/connections.py +++ b/dbt/adapters/mariadb/connections.py @@ -39,10 +39,7 @@ def __init__(self, **kwargs): def __post_init__(self): # Database and schema are treated as the same thing - if ( - self.database is not None and - self.database != self.schema - ): + if self.database is not None and self.database != self.schema: raise dbt.exceptions.RuntimeException( f" schema: {self.schema} \n" f" database: {self.database} \n" @@ -76,8 +73,8 @@ class MariaDBConnectionManager(SQLConnectionManager): @classmethod def open(cls, connection): - if connection.state == 'open': - logger.debug('Connection is already open, skipping open.') + if connection.state == "open": + logger.debug("Connection is already open, skipping open.") return connection credentials = cls.get_credentials(connection.credentials) @@ -96,26 +93,29 @@ def open(cls, connection): try: connection.handle = mysql.connector.connect(**kwargs) - connection.state = 'open' + connection.state = "open" except mysql.connector.Error: try: - logger.debug("Failed connection without supplying the `database`. " - "Trying again with `database` included.") + logger.debug( + "Failed connection without supplying the `database`. " + "Trying again with `database` included." + ) # Try again with the database included kwargs["database"] = credentials.schema connection.handle = mysql.connector.connect(**kwargs) - connection.state = 'open' + connection.state = "open" except mysql.connector.Error as e: - logger.debug("Got an error when attempting to open a MariaDB " - "connection: '{}'" - .format(e)) + logger.debug( + "Got an error when attempting to open a MariaDB " + "connection: '{}'".format(e) + ) connection.handle = None - connection.state = 'fail' + connection.state = "fail" raise dbt.exceptions.FailedToConnectException(str(e)) @@ -134,7 +134,7 @@ def exception_handler(self, sql): yield except mysql.connector.DatabaseError as e: - logger.debug('MariaDB error: {}'.format(str(e))) + logger.debug("MariaDB error: {}".format(str(e))) try: self.rollback_if_open() @@ -167,7 +167,5 @@ def get_response(cls, cursor) -> AdapterResponse: # There's no real way to get the status from the mysql-connector-python driver. # So just return the default value. return AdapterResponse( - _message="{} {}".format(code, num_rows), - rows_affected=num_rows, - code=code + _message="{} {}".format(code, num_rows), rows_affected=num_rows, code=code ) diff --git a/dbt/adapters/mariadb/impl.py b/dbt/adapters/mariadb/impl.py index 02d03be..6a16bcb 100644 --- a/dbt/adapters/mariadb/impl.py +++ b/dbt/adapters/mariadb/impl.py @@ -18,8 +18,8 @@ logger = AdapterLogger("mysql") -LIST_SCHEMAS_MACRO_NAME = 'list_schemas' -LIST_RELATIONS_MACRO_NAME = 'list_relations_without_caching' +LIST_SCHEMAS_MACRO_NAME = "list_schemas" +LIST_RELATIONS_MACRO_NAME = "list_relations_without_caching" class MariaDBAdapter(SQLAdapter): @@ -29,28 +29,23 @@ class MariaDBAdapter(SQLAdapter): @classmethod def date_function(cls): - return 'current_date()' + return "current_date()" @classmethod - def convert_datetime_type( - cls, agate_table: agate.Table, col_idx: int - ) -> str: + def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str: return "timestamp" def quote(self, identifier): - return '`{}`'.format(identifier) + return "`{}`".format(identifier) def list_relations_without_caching( self, schema_relation: MariaDBRelation ) -> List[MariaDBRelation]: - kwargs = {'schema_relation': schema_relation} + kwargs = {"schema_relation": schema_relation} try: - results = self.execute_macro( - LIST_RELATIONS_MACRO_NAME, - kwargs=kwargs - ) + results = self.execute_macro(LIST_RELATIONS_MACRO_NAME, kwargs=kwargs) except dbt.exceptions.RuntimeException as e: - errmsg = getattr(e, 'msg', '') + errmsg = getattr(e, "msg", "") if f"MariaDB database '{schema_relation}' not found" in errmsg: return [] else: @@ -64,13 +59,11 @@ def list_relations_without_caching( raise dbt.exceptions.RuntimeException( "Invalid value from " f'"mariadb__list_relations_without_caching({kwargs})", ' - f'got {len(row)} values, expected 4' + f"got {len(row)} values, expected 4" ) _, name, _schema, relation_type = row relation = self.Relation.create( - schema=_schema, - identifier=name, - type=relation_type + schema=_schema, identifier=name, type=relation_type ) relations.append(relation) @@ -88,9 +81,9 @@ def _get_columns_for_catalog( for column in columns: # convert MariaDBColumns into catalog dicts as_dict = asdict(column) - as_dict['column_name'] = as_dict.pop('column', None) - as_dict['column_type'] = as_dict.pop('dtype') - as_dict['table_database'] = None + as_dict["column_name"] = as_dict.pop("column", None) + as_dict["column_type"] = as_dict.pop("dtype") + as_dict["table_database"] = None yield as_dict def get_relation( @@ -102,48 +95,58 @@ def get_relation( return super().get_relation(database, schema, identifier) def parse_show_columns( - self, - relation: Relation, - raw_rows: List[agate.Row] + self, relation: Relation, raw_rows: List[agate.Row] ) -> List[MariaDBColumn]: - return [MariaDBColumn( - table_database=None, - table_schema=relation.schema, - table_name=relation.name, - table_type=relation.type, - table_owner=None, - table_stats=None, - column=column.column, - column_index=idx, - dtype=column.dtype, - ) for idx, column in enumerate(raw_rows)] + return [ + MariaDBColumn( + table_database=None, + table_schema=relation.schema, + table_name=relation.name, + table_type=relation.type, + table_owner=None, + table_stats=None, + column=column.column, + column_index=idx, + dtype=column.dtype, + ) + for idx, column in enumerate(raw_rows) + ] def get_catalog(self, manifest): schema_map = self._get_catalog_schemas(manifest) if len(schema_map) > 1: dbt.exceptions.raise_compiler_error( - f'Expected only one database in get_catalog, found ' - f'{list(schema_map)}' + f"Expected only one database in get_catalog, found " + f"{list(schema_map)}" ) with executor(self.config) as tpe: futures: List[Future[agate.Table]] = [] for info, schemas in schema_map.items(): for schema in schemas: - futures.append(tpe.submit_connected( - self, schema, - self._get_one_catalog, info, [schema], manifest - )) + futures.append( + tpe.submit_connected( + self, + schema, + self._get_one_catalog, + info, + [schema], + manifest, + ) + ) catalogs, exceptions = catch_as_completed(futures) return catalogs, exceptions def _get_one_catalog( - self, information_schema, schemas, manifest, + self, + information_schema, + schemas, + manifest, ) -> agate.Table: if len(schemas) != 1: dbt.exceptions.raise_compiler_error( - f'Expected only one schema in mariadb _get_one_catalog, found ' - f'{schemas}' + f"Expected only one schema in mariadb _get_one_catalog, found " + f"{schemas}" ) database = information_schema.database @@ -153,14 +156,11 @@ def _get_one_catalog( for relation in self.list_relations(database, schema): logger.debug("Getting table schema for relation {}", relation) columns.extend(self._get_columns_for_catalog(relation)) - return agate.Table.from_object( - columns, column_types=DEFAULT_TYPE_TESTER - ) + return agate.Table.from_object(columns, column_types=DEFAULT_TYPE_TESTER) def check_schema_exists(self, database, schema): results = self.execute_macro( - LIST_SCHEMAS_MACRO_NAME, - kwargs={'database': database} + LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database} ) exists = True if schema in [row[0] for row in results] else False @@ -174,13 +174,13 @@ def update_column_sql( clause: str, where_clause: Optional[str] = None, ) -> str: - clause = f'update {dst_name} set {dst_column} = {clause}' + clause = f"update {dst_name} set {dst_column} = {clause}" if where_clause is not None: - clause += f' where {where_clause}' + clause += f" where {where_clause}" return clause def timestamp_add_sql( - self, add_to: str, number: int = 1, interval: str = 'hour' + self, add_to: str, number: int = 1, interval: str = "hour" ) -> str: # for backwards compatibility, we're compelled to set some sort of # default. A lot of searching has lead me to believe that the @@ -189,11 +189,14 @@ def timestamp_add_sql( return f"date_add({add_to}, interval {number} {interval})" def string_add_sql( - self, add_to: str, value: str, location='append', + self, + add_to: str, + value: str, + location="append", ) -> str: - if location == 'append': + if location == "append": return f"concat({add_to}, '{value}')" - elif location == 'prepend': + elif location == "prepend": return f"concat({value}, '{add_to}')" else: raise dbt.exceptions.RuntimeException( @@ -216,15 +219,15 @@ def get_rows_different_sql( alias_a = "A" alias_b = "B" - columns_csv_a = ', '.join([f"{alias_a}.{name}" for name in names]) - columns_csv_b = ', '.join([f"{alias_b}.{name}" for name in names]) + columns_csv_a = ", ".join([f"{alias_a}.{name}" for name in names]) + columns_csv_b = ", ".join([f"{alias_b}.{name}" for name in names]) join_condition = " AND ".join( [f"{alias_a}.{name} = {alias_b}.{name}" for name in names] ) first_column = names[0] # There is no EXCEPT or MINUS operator, so we need to simulate it - COLUMNS_EQUAL_SQL = ''' + COLUMNS_EQUAL_SQL = """ SELECT row_count_diff.difference as row_count_difference, diff_count.num_missing as num_mismatched @@ -259,7 +262,7 @@ def get_rows_different_sql( ) as missing ) as diff_count ON row_count_diff.id = diff_count.id - '''.strip() + """.strip() sql = COLUMNS_EQUAL_SQL.format( alias_a=alias_a, diff --git a/dbt/adapters/mariadb/relation.py b/dbt/adapters/mariadb/relation.py index 2f498f1..0b21aa0 100644 --- a/dbt/adapters/mariadb/relation.py +++ b/dbt/adapters/mariadb/relation.py @@ -22,7 +22,7 @@ class MariaDBIncludePolicy(Policy): class MariaDBRelation(BaseRelation): quote_policy: MariaDBQuotePolicy = MariaDBQuotePolicy() include_policy: MariaDBIncludePolicy = MariaDBIncludePolicy() - quote_character: str = '`' + quote_character: str = "`" def __post_init__(self): if self.database != self.schema and self.database: diff --git a/dbt/adapters/mysql/__init__.py b/dbt/adapters/mysql/__init__.py index 9d2bbac..654b023 100644 --- a/dbt/adapters/mysql/__init__.py +++ b/dbt/adapters/mysql/__init__.py @@ -9,6 +9,5 @@ Plugin = AdapterPlugin( - adapter=MySQLAdapter, - credentials=MySQLCredentials, - include_path=mysql.PACKAGE_PATH) + adapter=MySQLAdapter, credentials=MySQLCredentials, include_path=mysql.PACKAGE_PATH +) diff --git a/dbt/adapters/mysql/column.py b/dbt/adapters/mysql/column.py index 3c27dbc..9ce3786 100644 --- a/dbt/adapters/mysql/column.py +++ b/dbt/adapters/mysql/column.py @@ -3,7 +3,7 @@ from dbt.adapters.base.column import Column -Self = TypeVar('Self', bound='MySQLColumn') +Self = TypeVar("Self", bound="MySQLColumn") @dataclass @@ -18,7 +18,7 @@ class MySQLColumn(Column): @property def quoted(self) -> str: - return '`{}`'.format(self.column) + return "`{}`".format(self.column) def __repr__(self) -> str: return "".format(self.name, self.data_type) diff --git a/dbt/adapters/mysql/connections.py b/dbt/adapters/mysql/connections.py index f85ba6b..6a4e285 100644 --- a/dbt/adapters/mysql/connections.py +++ b/dbt/adapters/mysql/connections.py @@ -38,10 +38,7 @@ def __init__(self, **kwargs): def __post_init__(self): # mysql classifies database and schema as the same thing - if ( - self.database is not None and - self.database != self.schema - ): + if self.database is not None and self.database != self.schema: raise dbt.exceptions.RuntimeException( f" schema: {self.schema} \n" f" database: {self.database} \n" @@ -75,8 +72,8 @@ class MySQLConnectionManager(SQLConnectionManager): @classmethod def open(cls, connection): - if connection.state == 'open': - logger.debug('Connection is already open, skipping open.') + if connection.state == "open": + logger.debug("Connection is already open, skipping open.") return connection credentials = cls.get_credentials(connection.credentials) @@ -92,26 +89,29 @@ def open(cls, connection): try: connection.handle = mysql.connector.connect(**kwargs) - connection.state = 'open' + connection.state = "open" except mysql.connector.Error: try: - logger.debug("Failed connection without supplying the `database`. " - "Trying again with `database` included.") + logger.debug( + "Failed connection without supplying the `database`. " + "Trying again with `database` included." + ) # Try again with the database included kwargs["database"] = credentials.schema connection.handle = mysql.connector.connect(**kwargs) - connection.state = 'open' + connection.state = "open" except mysql.connector.Error as e: - logger.debug("Got an error when attempting to open a mysql " - "connection: '{}'" - .format(e)) + logger.debug( + "Got an error when attempting to open a mysql " + "connection: '{}'".format(e) + ) connection.handle = None - connection.state = 'fail' + connection.state = "fail" raise dbt.exceptions.FailedToConnectException(str(e)) @@ -130,7 +130,7 @@ def exception_handler(self, sql): yield except mysql.connector.DatabaseError as e: - logger.debug('MySQL error: {}'.format(str(e))) + logger.debug("MySQL error: {}".format(str(e))) try: self.rollback_if_open() @@ -163,7 +163,5 @@ def get_response(cls, cursor) -> AdapterResponse: # There's no real way to get the status from the mysql-connector-python driver. # So just return the default value. return AdapterResponse( - _message="{} {}".format(code, num_rows), - rows_affected=num_rows, - code=code + _message="{} {}".format(code, num_rows), rows_affected=num_rows, code=code ) diff --git a/dbt/adapters/mysql/impl.py b/dbt/adapters/mysql/impl.py index 73e6159..f1e11d1 100644 --- a/dbt/adapters/mysql/impl.py +++ b/dbt/adapters/mysql/impl.py @@ -18,8 +18,8 @@ logger = AdapterLogger("mysql") -LIST_SCHEMAS_MACRO_NAME = 'list_schemas' -LIST_RELATIONS_MACRO_NAME = 'list_relations_without_caching' +LIST_SCHEMAS_MACRO_NAME = "list_schemas" +LIST_RELATIONS_MACRO_NAME = "list_relations_without_caching" class MySQLAdapter(SQLAdapter): @@ -29,28 +29,23 @@ class MySQLAdapter(SQLAdapter): @classmethod def date_function(cls): - return 'current_date()' + return "current_date()" @classmethod - def convert_datetime_type( - cls, agate_table: agate.Table, col_idx: int - ) -> str: + def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str: return "timestamp" def quote(self, identifier): - return '`{}`'.format(identifier) + return "`{}`".format(identifier) def list_relations_without_caching( self, schema_relation: MySQLRelation ) -> List[MySQLRelation]: - kwargs = {'schema_relation': schema_relation} + kwargs = {"schema_relation": schema_relation} try: - results = self.execute_macro( - LIST_RELATIONS_MACRO_NAME, - kwargs=kwargs - ) + results = self.execute_macro(LIST_RELATIONS_MACRO_NAME, kwargs=kwargs) except dbt.exceptions.RuntimeException as e: - errmsg = getattr(e, 'msg', '') + errmsg = getattr(e, "msg", "") if f"MySQL database '{schema_relation}' not found" in errmsg: return [] else: @@ -64,13 +59,11 @@ def list_relations_without_caching( raise dbt.exceptions.RuntimeException( "Invalid value from " f'"mysql__list_relations_without_caching({kwargs})", ' - f'got {len(row)} values, expected 4' + f"got {len(row)} values, expected 4" ) _, name, _schema, relation_type = row relation = self.Relation.create( - schema=_schema, - identifier=name, - type=relation_type + schema=_schema, identifier=name, type=relation_type ) relations.append(relation) @@ -88,9 +81,9 @@ def _get_columns_for_catalog( for column in columns: # convert MySQLColumns into catalog dicts as_dict = asdict(column) - as_dict['column_name'] = as_dict.pop('column', None) - as_dict['column_type'] = as_dict.pop('dtype') - as_dict['table_database'] = None + as_dict["column_name"] = as_dict.pop("column", None) + as_dict["column_type"] = as_dict.pop("dtype") + as_dict["table_database"] = None yield as_dict def get_relation( @@ -102,49 +95,59 @@ def get_relation( return super().get_relation(database, schema, identifier) def parse_show_columns( - self, - relation: Relation, - raw_rows: List[agate.Row] + self, relation: Relation, raw_rows: List[agate.Row] ) -> List[MySQLColumn]: - return [MySQLColumn( - table_database=None, - table_schema=relation.schema, - table_name=relation.name, - table_type=relation.type, - table_owner=None, - table_stats=None, - column=column.column, - column_index=idx, - dtype=column.dtype, - ) for idx, column in enumerate(raw_rows)] + return [ + MySQLColumn( + table_database=None, + table_schema=relation.schema, + table_name=relation.name, + table_type=relation.type, + table_owner=None, + table_stats=None, + column=column.column, + column_index=idx, + dtype=column.dtype, + ) + for idx, column in enumerate(raw_rows) + ] def get_catalog(self, manifest): schema_map = self._get_catalog_schemas(manifest) if len(schema_map) > 1: dbt.exceptions.raise_compiler_error( - f'Expected only one database in get_catalog, found ' - f'{list(schema_map)}' + f"Expected only one database in get_catalog, found " + f"{list(schema_map)}" ) with executor(self.config) as tpe: futures: List[Future[agate.Table]] = [] for info, schemas in schema_map.items(): for schema in schemas: - futures.append(tpe.submit_connected( - self, schema, - self._get_one_catalog, info, [schema], manifest - )) + futures.append( + tpe.submit_connected( + self, + schema, + self._get_one_catalog, + info, + [schema], + manifest, + ) + ) catalogs, exceptions = catch_as_completed(futures) return catalogs, exceptions def _get_one_catalog( - self, information_schema, schemas, manifest, + self, + information_schema, + schemas, + manifest, ) -> agate.Table: if len(schemas) != 1: dbt.exceptions.raise_compiler_error( - f'Expected only one schema in mysql _get_one_catalog, found ' - f'{schemas}' + f"Expected only one schema in mysql _get_one_catalog, found " + f"{schemas}" ) database = information_schema.database @@ -154,14 +157,11 @@ def _get_one_catalog( for relation in self.list_relations(database, schema): logger.debug("Getting table schema for relation {}", relation) columns.extend(self._get_columns_for_catalog(relation)) - return agate.Table.from_object( - columns, column_types=DEFAULT_TYPE_TESTER - ) + return agate.Table.from_object(columns, column_types=DEFAULT_TYPE_TESTER) def check_schema_exists(self, database, schema): results = self.execute_macro( - LIST_SCHEMAS_MACRO_NAME, - kwargs={'database': database} + LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database} ) exists = True if schema in [row[0] for row in results] else False @@ -175,13 +175,13 @@ def update_column_sql( clause: str, where_clause: Optional[str] = None, ) -> str: - clause = f'update {dst_name} set {dst_column} = {clause}' + clause = f"update {dst_name} set {dst_column} = {clause}" if where_clause is not None: - clause += f' where {where_clause}' + clause += f" where {where_clause}" return clause def timestamp_add_sql( - self, add_to: str, number: int = 1, interval: str = 'hour' + self, add_to: str, number: int = 1, interval: str = "hour" ) -> str: # for backwards compatibility, we're compelled to set some sort of # default. A lot of searching has lead me to believe that the @@ -190,11 +190,14 @@ def timestamp_add_sql( return f"date_add({add_to}, interval {number} {interval})" def string_add_sql( - self, add_to: str, value: str, location='append', + self, + add_to: str, + value: str, + location="append", ) -> str: - if location == 'append': + if location == "append": return f"concat({add_to}, '{value}')" - elif location == 'prepend': + elif location == "prepend": return f"concat({value}, '{add_to}')" else: raise dbt.exceptions.RuntimeException( @@ -217,15 +220,15 @@ def get_rows_different_sql( alias_a = "A" alias_b = "B" - columns_csv_a = ', '.join([f"{alias_a}.{name}" for name in names]) - columns_csv_b = ', '.join([f"{alias_b}.{name}" for name in names]) + columns_csv_a = ", ".join([f"{alias_a}.{name}" for name in names]) + columns_csv_b = ", ".join([f"{alias_b}.{name}" for name in names]) join_condition = " AND ".join( [f"{alias_a}.{name} = {alias_b}.{name}" for name in names] ) first_column = names[0] # MySQL doesn't have an EXCEPT or MINUS operator, so we need to simulate it - COLUMNS_EQUAL_SQL = ''' + COLUMNS_EQUAL_SQL = """ WITH a_except_b as ( SELECT @@ -269,7 +272,7 @@ def get_rows_different_sql( diff_count.num_missing as num_mismatched FROM row_count_diff INNER JOIN diff_count ON row_count_diff.id = diff_count.id - '''.strip() + """.strip() sql = COLUMNS_EQUAL_SQL.format( alias_a=alias_a, diff --git a/dbt/adapters/mysql/relation.py b/dbt/adapters/mysql/relation.py index 21cec97..859afc1 100644 --- a/dbt/adapters/mysql/relation.py +++ b/dbt/adapters/mysql/relation.py @@ -22,7 +22,7 @@ class MySQLIncludePolicy(Policy): class MySQLRelation(BaseRelation): quote_policy: MySQLQuotePolicy = MySQLQuotePolicy() include_policy: MySQLIncludePolicy = MySQLIncludePolicy() - quote_character: str = '`' + quote_character: str = "`" def __post_init__(self): if self.database != self.schema and self.database: diff --git a/dbt/adapters/mysql5/__init__.py b/dbt/adapters/mysql5/__init__.py index c0e2c84..8f23e58 100644 --- a/dbt/adapters/mysql5/__init__.py +++ b/dbt/adapters/mysql5/__init__.py @@ -9,6 +9,5 @@ Plugin = AdapterPlugin( - adapter=MySQLAdapter, - credentials=MySQLCredentials, - include_path=mysql5.PACKAGE_PATH) + adapter=MySQLAdapter, credentials=MySQLCredentials, include_path=mysql5.PACKAGE_PATH +) diff --git a/dbt/adapters/mysql5/column.py b/dbt/adapters/mysql5/column.py index 3c27dbc..9ce3786 100644 --- a/dbt/adapters/mysql5/column.py +++ b/dbt/adapters/mysql5/column.py @@ -3,7 +3,7 @@ from dbt.adapters.base.column import Column -Self = TypeVar('Self', bound='MySQLColumn') +Self = TypeVar("Self", bound="MySQLColumn") @dataclass @@ -18,7 +18,7 @@ class MySQLColumn(Column): @property def quoted(self) -> str: - return '`{}`'.format(self.column) + return "`{}`".format(self.column) def __repr__(self) -> str: return "".format(self.name, self.data_type) diff --git a/dbt/adapters/mysql5/connections.py b/dbt/adapters/mysql5/connections.py index 9086e31..c8c1d20 100644 --- a/dbt/adapters/mysql5/connections.py +++ b/dbt/adapters/mysql5/connections.py @@ -39,10 +39,7 @@ def __init__(self, **kwargs): def __post_init__(self): # mysql classifies database and schema as the same thing - if ( - self.database is not None and - self.database != self.schema - ): + if self.database is not None and self.database != self.schema: raise dbt.exceptions.RuntimeException( f" schema: {self.schema} \n" f" database: {self.database} \n" @@ -76,8 +73,8 @@ class MySQLConnectionManager(SQLConnectionManager): @classmethod def open(cls, connection): - if connection.state == 'open': - logger.debug('Connection is already open, skipping open.') + if connection.state == "open": + logger.debug("Connection is already open, skipping open.") return connection credentials = cls.get_credentials(connection.credentials) @@ -96,26 +93,29 @@ def open(cls, connection): try: connection.handle = mysql.connector.connect(**kwargs) - connection.state = 'open' + connection.state = "open" except mysql.connector.Error: try: - logger.debug("Failed connection without supplying the `database`. " - "Trying again with `database` included.") + logger.debug( + "Failed connection without supplying the `database`. " + "Trying again with `database` included." + ) # Try again with the database included kwargs["database"] = credentials.schema connection.handle = mysql.connector.connect(**kwargs) - connection.state = 'open' + connection.state = "open" except mysql.connector.Error as e: - logger.debug("Got an error when attempting to open a mysql " - "connection: '{}'" - .format(e)) + logger.debug( + "Got an error when attempting to open a mysql " + "connection: '{}'".format(e) + ) connection.handle = None - connection.state = 'fail' + connection.state = "fail" raise dbt.exceptions.FailedToConnectException(str(e)) @@ -134,7 +134,7 @@ def exception_handler(self, sql): yield except mysql.connector.DatabaseError as e: - logger.debug('MySQL error: {}'.format(str(e))) + logger.debug("MySQL error: {}".format(str(e))) try: self.rollback_if_open() @@ -167,7 +167,5 @@ def get_response(cls, cursor) -> AdapterResponse: # There's no real way to get the status from the mysql-connector-python driver. # So just return the default value. return AdapterResponse( - _message="{} {}".format(code, num_rows), - rows_affected=num_rows, - code=code + _message="{} {}".format(code, num_rows), rows_affected=num_rows, code=code ) diff --git a/dbt/adapters/mysql5/impl.py b/dbt/adapters/mysql5/impl.py index df5cc3a..2582c83 100644 --- a/dbt/adapters/mysql5/impl.py +++ b/dbt/adapters/mysql5/impl.py @@ -18,8 +18,8 @@ logger = AdapterLogger("mysql") -LIST_SCHEMAS_MACRO_NAME = 'list_schemas' -LIST_RELATIONS_MACRO_NAME = 'list_relations_without_caching' +LIST_SCHEMAS_MACRO_NAME = "list_schemas" +LIST_RELATIONS_MACRO_NAME = "list_relations_without_caching" class MySQLAdapter(SQLAdapter): @@ -29,28 +29,23 @@ class MySQLAdapter(SQLAdapter): @classmethod def date_function(cls): - return 'current_date()' + return "current_date()" @classmethod - def convert_datetime_type( - cls, agate_table: agate.Table, col_idx: int - ) -> str: + def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str: return "timestamp" def quote(self, identifier): - return '`{}`'.format(identifier) + return "`{}`".format(identifier) def list_relations_without_caching( self, schema_relation: MySQLRelation ) -> List[MySQLRelation]: - kwargs = {'schema_relation': schema_relation} + kwargs = {"schema_relation": schema_relation} try: - results = self.execute_macro( - LIST_RELATIONS_MACRO_NAME, - kwargs=kwargs - ) + results = self.execute_macro(LIST_RELATIONS_MACRO_NAME, kwargs=kwargs) except dbt.exceptions.RuntimeException as e: - errmsg = getattr(e, 'msg', '') + errmsg = getattr(e, "msg", "") if f"MySQL database '{schema_relation}' not found" in errmsg: return [] else: @@ -64,13 +59,11 @@ def list_relations_without_caching( raise dbt.exceptions.RuntimeException( "Invalid value from " f'"mysql5__list_relations_without_caching({kwargs})", ' - f'got {len(row)} values, expected 4' + f"got {len(row)} values, expected 4" ) _, name, _schema, relation_type = row relation = self.Relation.create( - schema=_schema, - identifier=name, - type=relation_type + schema=_schema, identifier=name, type=relation_type ) relations.append(relation) @@ -88,9 +81,9 @@ def _get_columns_for_catalog( for column in columns: # convert MySQLColumns into catalog dicts as_dict = asdict(column) - as_dict['column_name'] = as_dict.pop('column', None) - as_dict['column_type'] = as_dict.pop('dtype') - as_dict['table_database'] = None + as_dict["column_name"] = as_dict.pop("column", None) + as_dict["column_type"] = as_dict.pop("dtype") + as_dict["table_database"] = None yield as_dict def get_relation( @@ -102,48 +95,58 @@ def get_relation( return super().get_relation(database, schema, identifier) def parse_show_columns( - self, - relation: Relation, - raw_rows: List[agate.Row] + self, relation: Relation, raw_rows: List[agate.Row] ) -> List[MySQLColumn]: - return [MySQLColumn( - table_database=None, - table_schema=relation.schema, - table_name=relation.name, - table_type=relation.type, - table_owner=None, - table_stats=None, - column=column.column, - column_index=idx, - dtype=column.dtype, - ) for idx, column in enumerate(raw_rows)] + return [ + MySQLColumn( + table_database=None, + table_schema=relation.schema, + table_name=relation.name, + table_type=relation.type, + table_owner=None, + table_stats=None, + column=column.column, + column_index=idx, + dtype=column.dtype, + ) + for idx, column in enumerate(raw_rows) + ] def get_catalog(self, manifest): schema_map = self._get_catalog_schemas(manifest) if len(schema_map) > 1: dbt.exceptions.raise_compiler_error( - f'Expected only one database in get_catalog, found ' - f'{list(schema_map)}' + f"Expected only one database in get_catalog, found " + f"{list(schema_map)}" ) with executor(self.config) as tpe: futures: List[Future[agate.Table]] = [] for info, schemas in schema_map.items(): for schema in schemas: - futures.append(tpe.submit_connected( - self, schema, - self._get_one_catalog, info, [schema], manifest - )) + futures.append( + tpe.submit_connected( + self, + schema, + self._get_one_catalog, + info, + [schema], + manifest, + ) + ) catalogs, exceptions = catch_as_completed(futures) return catalogs, exceptions def _get_one_catalog( - self, information_schema, schemas, manifest, + self, + information_schema, + schemas, + manifest, ) -> agate.Table: if len(schemas) != 1: dbt.exceptions.raise_compiler_error( - f'Expected only one schema in mysql5 _get_one_catalog, found ' - f'{schemas}' + f"Expected only one schema in mysql5 _get_one_catalog, found " + f"{schemas}" ) database = information_schema.database @@ -153,14 +156,11 @@ def _get_one_catalog( for relation in self.list_relations(database, schema): logger.debug("Getting table schema for relation {}", relation) columns.extend(self._get_columns_for_catalog(relation)) - return agate.Table.from_object( - columns, column_types=DEFAULT_TYPE_TESTER - ) + return agate.Table.from_object(columns, column_types=DEFAULT_TYPE_TESTER) def check_schema_exists(self, database, schema): results = self.execute_macro( - LIST_SCHEMAS_MACRO_NAME, - kwargs={'database': database} + LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database} ) exists = True if schema in [row[0] for row in results] else False @@ -174,13 +174,13 @@ def update_column_sql( clause: str, where_clause: Optional[str] = None, ) -> str: - clause = f'update {dst_name} set {dst_column} = {clause}' + clause = f"update {dst_name} set {dst_column} = {clause}" if where_clause is not None: - clause += f' where {where_clause}' + clause += f" where {where_clause}" return clause def timestamp_add_sql( - self, add_to: str, number: int = 1, interval: str = 'hour' + self, add_to: str, number: int = 1, interval: str = "hour" ) -> str: # for backwards compatibility, we're compelled to set some sort of # default. A lot of searching has lead me to believe that the @@ -189,11 +189,14 @@ def timestamp_add_sql( return f"date_add({add_to}, interval {number} {interval})" def string_add_sql( - self, add_to: str, value: str, location='append', + self, + add_to: str, + value: str, + location="append", ) -> str: - if location == 'append': + if location == "append": return f"concat({add_to}, '{value}')" - elif location == 'prepend': + elif location == "prepend": return f"concat({value}, '{add_to}')" else: raise dbt.exceptions.RuntimeException( @@ -216,15 +219,15 @@ def get_rows_different_sql( alias_a = "A" alias_b = "B" - columns_csv_a = ', '.join([f"{alias_a}.{name}" for name in names]) - columns_csv_b = ', '.join([f"{alias_b}.{name}" for name in names]) + columns_csv_a = ", ".join([f"{alias_a}.{name}" for name in names]) + columns_csv_b = ", ".join([f"{alias_b}.{name}" for name in names]) join_condition = " AND ".join( [f"{alias_a}.{name} = {alias_b}.{name}" for name in names] ) first_column = names[0] # MySQL doesn't have an EXCEPT or MINUS operator, so we need to simulate it - COLUMNS_EQUAL_SQL = ''' + COLUMNS_EQUAL_SQL = """ SELECT row_count_diff.difference as row_count_difference, diff_count.num_missing as num_mismatched @@ -259,7 +262,7 @@ def get_rows_different_sql( ) as missing ) as diff_count ON row_count_diff.id = diff_count.id - '''.strip() + """.strip() sql = COLUMNS_EQUAL_SQL.format( alias_a=alias_a, diff --git a/dbt/adapters/mysql5/relation.py b/dbt/adapters/mysql5/relation.py index 2c781d5..1b03317 100644 --- a/dbt/adapters/mysql5/relation.py +++ b/dbt/adapters/mysql5/relation.py @@ -22,7 +22,7 @@ class MySQLIncludePolicy(Policy): class MySQLRelation(BaseRelation): quote_policy: MySQLQuotePolicy = MySQLQuotePolicy() include_policy: MySQLIncludePolicy = MySQLIncludePolicy() - quote_character: str = '`' + quote_character: str = "`" def __post_init__(self): if self.database != self.schema and self.database: diff --git a/dbt/include/mariadb/__init__.py b/dbt/include/mariadb/__init__.py index 564a3d1..b177e5d 100644 --- a/dbt/include/mariadb/__init__.py +++ b/dbt/include/mariadb/__init__.py @@ -1,2 +1,3 @@ import os + PACKAGE_PATH = os.path.dirname(__file__) diff --git a/dbt/include/mariadb/macros/materializations/test/test.sql b/dbt/include/mariadb/macros/materializations/test/test.sql index 38dcf4e..843aa81 100644 --- a/dbt/include/mariadb/macros/materializations/test/test.sql +++ b/dbt/include/mariadb/macros/materializations/test/test.sql @@ -13,4 +13,4 @@ {{ main_sql }} {{ "limit " ~ limit if limit != none }} ) dbt_internal_test -{%- endmacro %} \ No newline at end of file +{%- endmacro %} diff --git a/dbt/include/mysql/__init__.py b/dbt/include/mysql/__init__.py index 564a3d1..b177e5d 100644 --- a/dbt/include/mysql/__init__.py +++ b/dbt/include/mysql/__init__.py @@ -1,2 +1,3 @@ import os + PACKAGE_PATH = os.path.dirname(__file__) diff --git a/dbt/include/mysql/macros/materializations/test/test.sql b/dbt/include/mysql/macros/materializations/test/test.sql index 92c16af..3b8c05a 100644 --- a/dbt/include/mysql/macros/materializations/test/test.sql +++ b/dbt/include/mysql/macros/materializations/test/test.sql @@ -13,4 +13,4 @@ {{ main_sql }} {{ "limit " ~ limit if limit != none }} ) dbt_internal_test -{%- endmacro %} \ No newline at end of file +{%- endmacro %} diff --git a/dbt/include/mysql5/__init__.py b/dbt/include/mysql5/__init__.py index 564a3d1..b177e5d 100644 --- a/dbt/include/mysql5/__init__.py +++ b/dbt/include/mysql5/__init__.py @@ -1,2 +1,3 @@ import os + PACKAGE_PATH = os.path.dirname(__file__) diff --git a/dbt/include/mysql5/macros/materializations/test/test.sql b/dbt/include/mysql5/macros/materializations/test/test.sql index af7036a..baabc04 100644 --- a/dbt/include/mysql5/macros/materializations/test/test.sql +++ b/dbt/include/mysql5/macros/materializations/test/test.sql @@ -13,4 +13,4 @@ {{ main_sql }} {{ "limit " ~ limit if limit != none }} ) dbt_internal_test -{%- endmacro %} \ No newline at end of file +{%- endmacro %} diff --git a/setup.py b/setup.py index 8e9cea1..1441b2c 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,9 @@ from setuptools import find_namespace_packages except ImportError: print("Error: dbt requires setuptools v40.1.0 or higher.") - print('Please upgrade setuptools with "pip install --upgrade setuptools" and try again') + print( + 'Please upgrade setuptools with "pip install --upgrade setuptools" and try again' + ) sys.exit(1) @@ -86,14 +88,14 @@ def _core_version(plugin_version: str = _plugin_version()) -> str: ], zip_safe=False, classifiers=[ - 'Development Status :: 3 - Alpha', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: Microsoft :: Windows', - 'Operating System :: MacOS :: MacOS X', - 'Operating System :: POSIX :: Linux', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', + "Development Status :: 3 - Alpha", + "License :: OSI Approved :: Apache Software License", + "Operating System :: Microsoft :: Windows", + "Operating System :: MacOS :: MacOS X", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", ], python_requires=">=3.7", ) diff --git a/tests/conftest.py b/tests/conftest.py index 20118dd..6cb3084 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -37,33 +37,33 @@ def dbt_profile_target(request): # dbt will supply a unique schema per test, so we do not specify 'schema' here def mysql_target(): return { - 'type': 'mysql', - 'port': int(os.getenv('DBT_MYSQL_80_PORT', '3306')), - 'server': os.getenv('DBT_MYSQL_SERVER_NAME', 'localhost'), - 'username': os.getenv('DBT_MYSQL_USERNAME', 'root'), - 'password': os.getenv('DBT_MYSQL_PASSWORD', 'dbt'), + "type": "mysql", + "port": int(os.getenv("DBT_MYSQL_80_PORT", "3306")), + "server": os.getenv("DBT_MYSQL_SERVER_NAME", "localhost"), + "username": os.getenv("DBT_MYSQL_USERNAME", "root"), + "password": os.getenv("DBT_MYSQL_PASSWORD", "dbt"), } # dbt will supply a unique schema per test, so we do not specify 'schema' here def mysql5_target(): return { - 'type': 'mysql5', - 'port': int(os.getenv('DBT_MYSQL_57_PORT', '3306')), - 'server': os.getenv('DBT_MYSQL_SERVER_NAME', 'localhost'), - 'username': os.getenv('DBT_MYSQL_USERNAME', 'root'), - 'password': os.getenv('DBT_MYSQL_PASSWORD', 'dbt'), + "type": "mysql5", + "port": int(os.getenv("DBT_MYSQL_57_PORT", "3306")), + "server": os.getenv("DBT_MYSQL_SERVER_NAME", "localhost"), + "username": os.getenv("DBT_MYSQL_USERNAME", "root"), + "password": os.getenv("DBT_MYSQL_PASSWORD", "dbt"), } # dbt will supply a unique schema per test, so we do not specify 'schema' here def mariadb_target(): return { - 'type': 'mariadb', - 'port': int(os.getenv('DBT_MARIADB_105_PORT', '3306')), - 'server': os.getenv('DBT_MYSQL_SERVER_NAME', 'localhost'), - 'username': os.getenv('DBT_MYSQL_USERNAME', 'root'), - 'password': os.getenv('DBT_MYSQL_PASSWORD', 'dbt'), + "type": "mariadb", + "port": int(os.getenv("DBT_MARIADB_105_PORT", "3306")), + "server": os.getenv("DBT_MYSQL_SERVER_NAME", "localhost"), + "username": os.getenv("DBT_MYSQL_USERNAME", "root"), + "password": os.getenv("DBT_MYSQL_PASSWORD", "dbt"), } diff --git a/tests/functional/adapter/test_basic.py b/tests/functional/adapter/test_basic.py index 0360b59..8505b1c 100644 --- a/tests/functional/adapter/test_basic.py +++ b/tests/functional/adapter/test_basic.py @@ -28,7 +28,7 @@ class TestSingularTestsMySQL(BaseSingularTests): # Ephemeral materializations not supported for MySQL 5.7 -@pytest.mark.skip_profile('mysql5') +@pytest.mark.skip_profile("mysql5") class TestSingularTestsEphemeralMySQL(BaseSingularTestsEphemeral): pass @@ -38,7 +38,7 @@ class TestEmptyMySQL(BaseEmpty): # Ephemeral materializations not supported for MySQL 5.7 -@pytest.mark.skip_profile('mysql5') +@pytest.mark.skip_profile("mysql5") class TestEphemeralMySQL(BaseEphemeral): pass diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index a242a80..8c499d5 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -7,35 +7,34 @@ class TestMySQLAdapter(unittest.TestCase): - def setUp(self): pass flags.STRICT_MODE = True profile_cfg = { - 'outputs': { - 'test': { - 'type': 'mysql', - 'server': 'thishostshouldnotexist', - 'port': 3306, - 'schema': 'dbt_test_schema', - 'username': 'dbt', - 'password': 'dbt', + "outputs": { + "test": { + "type": "mysql", + "server": "thishostshouldnotexist", + "port": 3306, + "schema": "dbt_test_schema", + "username": "dbt", + "password": "dbt", } }, - 'target': 'test' + "target": "test", } project_cfg = { - 'name': 'X', - 'version': '0.1', - 'profile': 'test', - 'project-root': '/tmp/dbt/does-not-exist', - 'quoting': { - 'identifier': False, - 'schema': True, + "name": "X", + "version": "0.1", + "profile": "test", + "project-root": "/tmp/dbt/does-not-exist", + "quoting": { + "identifier": False, + "schema": True, }, - 'config-version': 2 + "config-version": 2, } self.config = config_from_parts_or_dicts(project_cfg, profile_cfg) @@ -47,13 +46,13 @@ def adapter(self): self._adapter = MySQLAdapter(self.config) return self._adapter - @mock.patch('dbt.adapters.mysql.connections.mysql.connector') + @mock.patch("dbt.adapters.mysql.connections.mysql.connector") def test_acquire_connection(self, connector): - connection = self.adapter.acquire_connection('dummy') + connection = self.adapter.acquire_connection("dummy") connector.connect.assert_not_called() connection.handle - self.assertEqual(connection.state, 'open') + self.assertEqual(connection.state, "open") self.assertNotEqual(connection.handle, None) connector.connect.assert_called_once() @@ -62,8 +61,7 @@ def test_cancel_open_connections_empty(self): def test_cancel_open_connections_main(self): key = self.adapter.connections.get_thread_identifier() - self.adapter.connections.thread_connections[key] = mock_connection( - 'main') + self.adapter.connections.thread_connections[key] = mock_connection("main") self.assertEqual(len(list(self.adapter.cancel_open_connections())), 0) def test_placeholder(self): diff --git a/tests/unit/utils.py b/tests/unit/utils.py index 32db0ef..07371d3 100644 --- a/tests/unit/utils.py +++ b/tests/unit/utils.py @@ -25,7 +25,7 @@ def normalize(path): class Obj: - which = 'blah' + which = "blah" single_threaded = False @@ -35,11 +35,11 @@ def mock_connection(name): return conn -def profile_from_dict(profile, profile_name, cli_vars='{}'): +def profile_from_dict(profile, profile_name, cli_vars="{}"): from dbt.config import Profile from dbt.config.renderer import ProfileRenderer - from dbt.context.base import generate_base_context from dbt.config.utils import parse_cli_vars + if not isinstance(cli_vars, dict): cli_vars = parse_cli_vars(cli_vars) @@ -51,17 +51,16 @@ def profile_from_dict(profile, profile_name, cli_vars='{}'): ) -def project_from_dict(project, profile, packages=None, selectors=None, cli_vars='{}'): - from dbt.context.target import generate_target_context - from dbt.config import Project +def project_from_dict(project, profile, packages=None, selectors=None, cli_vars="{}"): from dbt.config.renderer import DbtProjectYamlRenderer from dbt.config.utils import parse_cli_vars + if not isinstance(cli_vars, dict): cli_vars = parse_cli_vars(cli_vars) renderer = DbtProjectYamlRenderer(profile, cli_vars) - project_root = project.pop('project-root', os.getcwd()) + project_root = project.pop("project-root", os.getcwd()) partial = PartialProject.from_dicts( project_root=project_root, @@ -72,14 +71,16 @@ def project_from_dict(project, profile, packages=None, selectors=None, cli_vars= return partial.render(renderer) -def config_from_parts_or_dicts(project, profile, packages=None, selectors=None, cli_vars='{}'): +def config_from_parts_or_dicts( + project, profile, packages=None, selectors=None, cli_vars="{}" +): from dbt.config import Project, Profile, RuntimeConfig from copy import deepcopy if isinstance(project, Project): profile_name = project.profile_name else: - profile_name = project.get('profile') + profile_name = project.get("profile") if not isinstance(profile, Profile): profile = profile_from_dict( @@ -99,16 +100,13 @@ def config_from_parts_or_dicts(project, profile, packages=None, selectors=None, args = Obj() args.vars = cli_vars - args.profile_dir = '/dev/null' - return RuntimeConfig.from_parts( - project=project, - profile=profile, - args=args - ) + args.profile_dir = "/dev/null" + return RuntimeConfig.from_parts(project=project, profile=profile, args=args) def inject_plugin(plugin): from dbt.adapters.factory import FACTORY + key = plugin.adapter.type() FACTORY.plugins[key] = plugin @@ -119,6 +117,7 @@ def inject_adapter(value, plugin): """ inject_plugin(plugin) from dbt.adapters.factory import FACTORY + key = value.type() FACTORY.adapters[key] = value @@ -136,7 +135,7 @@ def assert_to_dict(self, obj, dct): def assert_from_dict(self, obj, dct, cls=None): if cls is None: cls = self.ContractType - self.assertEqual(cls.from_dict(dct), obj) + self.assertEqual(cls.from_dict(dct), obj) def assert_symmetric(self, obj, dct, cls=None): self.assert_to_dict(obj, dct) @@ -153,26 +152,27 @@ def assert_fails_validation(self, dct, cls=None): def generate_name_macros(package): from dbt.contracts.graph.parsed import ParsedMacro from dbt.node_types import NodeType + name_sql = {} - for component in ('database', 'schema', 'alias'): - if component == 'alias': - source = 'node.name' + for component in ("database", "schema", "alias"): + if component == "alias": + source = "node.name" else: - source = f'target.{component}' - name = f'generate_{component}_name' - sql = f'{{% macro {name}(value, node) %}} {{% if value %}} {{{{ value }}}} {{% else %}} {{{{ {source} }}}} {{% endif %}} {{% endmacro %}}' + source = f"target.{component}" + name = f"generate_{component}_name" + sql = f"{{% macro {name}(value, node) %}} {{% if value %}} {{{{ value }}}} {{% else %}} {{{{ {source} }}}} {{% endif %}} {{% endmacro %}}" name_sql[name] = sql - all_sql = '\n'.join(name_sql.values()) + all_sql = "\n".join(name_sql.values()) for name, sql in name_sql.items(): pm = ParsedMacro( name=name, resource_type=NodeType.Macro, - unique_id=f'macro.{package}.{name}', + unique_id=f"macro.{package}.{name}", package_name=package, - original_file_path=normalize('macros/macro.sql'), - root_path='./dbt_modules/root', - path=normalize('macros/macro.sql'), + original_file_path=normalize("macros/macro.sql"), + root_path="./dbt_modules/root", + path=normalize("macros/macro.sql"), raw_sql=all_sql, macro_sql=sql, )