Skip to content

Commit

Permalink
Apply code quality formatting rules (#140)
Browse files Browse the repository at this point in the history
  • Loading branch information
dbeatty10 authored Feb 6, 2023
1 parent c1355a7 commit 509440e
Show file tree
Hide file tree
Showing 28 changed files with 331 additions and 327 deletions.
5 changes: 2 additions & 3 deletions .bumpversion.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ parse = (?P<major>\d+)
((?P<prerelease>[a-z]+)
?(\.)?
(?P<num>\d+))?
serialize =
serialize =
{major}.{minor}.{patch}{prerelease}{num}
{major}.{minor}.{patch}

[bumpversion:part:prerelease]
first_value = a
values =
values =
a
b
rc
Expand All @@ -33,4 +33,3 @@ replace = version = "{new_version}"
[bumpversion:file:dbt/adapters/mariadb/__version__.py]
search = version = "{current_version}"
replace = version = "{new_version}"

2 changes: 1 addition & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1 +1 @@
recursive-include dbt/include *.sql *.yml *.md
recursive-include dbt/include *.sql *.yml *.md
3 changes: 2 additions & 1 deletion dbt/adapters/mariadb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@
Plugin = AdapterPlugin(
adapter=MariaDBAdapter,
credentials=MariaDBCredentials,
include_path=mariadb.PACKAGE_PATH)
include_path=mariadb.PACKAGE_PATH,
)
4 changes: 2 additions & 2 deletions dbt/adapters/mariadb/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from dbt.adapters.base.column import Column

Self = TypeVar('Self', bound='MariaDBColumn')
Self = TypeVar("Self", bound="MariaDBColumn")


@dataclass
Expand All @@ -18,7 +18,7 @@ class MariaDBColumn(Column):

@property
def quoted(self) -> str:
return '`{}`'.format(self.column)
return "`{}`".format(self.column)

def __repr__(self) -> str:
return "<MariaDBColumn {} ({})>".format(self.name, self.data_type)
34 changes: 16 additions & 18 deletions dbt/adapters/mariadb/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,7 @@ def __init__(self, **kwargs):

def __post_init__(self):
# Database and schema are treated as the same thing
if (
self.database is not None and
self.database != self.schema
):
if self.database is not None and self.database != self.schema:
raise dbt.exceptions.RuntimeException(
f" schema: {self.schema} \n"
f" database: {self.database} \n"
Expand Down Expand Up @@ -76,8 +73,8 @@ class MariaDBConnectionManager(SQLConnectionManager):

@classmethod
def open(cls, connection):
if connection.state == 'open':
logger.debug('Connection is already open, skipping open.')
if connection.state == "open":
logger.debug("Connection is already open, skipping open.")
return connection

credentials = cls.get_credentials(connection.credentials)
Expand All @@ -96,26 +93,29 @@ def open(cls, connection):

try:
connection.handle = mysql.connector.connect(**kwargs)
connection.state = 'open'
connection.state = "open"
except mysql.connector.Error:

try:
logger.debug("Failed connection without supplying the `database`. "
"Trying again with `database` included.")
logger.debug(
"Failed connection without supplying the `database`. "
"Trying again with `database` included."
)

# Try again with the database included
kwargs["database"] = credentials.schema

connection.handle = mysql.connector.connect(**kwargs)
connection.state = 'open'
connection.state = "open"
except mysql.connector.Error as e:

logger.debug("Got an error when attempting to open a MariaDB "
"connection: '{}'"
.format(e))
logger.debug(
"Got an error when attempting to open a MariaDB "
"connection: '{}'".format(e)
)

connection.handle = None
connection.state = 'fail'
connection.state = "fail"

raise dbt.exceptions.FailedToConnectException(str(e))

Expand All @@ -134,7 +134,7 @@ def exception_handler(self, sql):
yield

except mysql.connector.DatabaseError as e:
logger.debug('MariaDB error: {}'.format(str(e)))
logger.debug("MariaDB error: {}".format(str(e)))

try:
self.rollback_if_open()
Expand Down Expand Up @@ -167,7 +167,5 @@ def get_response(cls, cursor) -> AdapterResponse:
# There's no real way to get the status from the mysql-connector-python driver.
# So just return the default value.
return AdapterResponse(
_message="{} {}".format(code, num_rows),
rows_affected=num_rows,
code=code
_message="{} {}".format(code, num_rows), rows_affected=num_rows, code=code
)
119 changes: 61 additions & 58 deletions dbt/adapters/mariadb/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

logger = AdapterLogger("mysql")

LIST_SCHEMAS_MACRO_NAME = 'list_schemas'
LIST_RELATIONS_MACRO_NAME = 'list_relations_without_caching'
LIST_SCHEMAS_MACRO_NAME = "list_schemas"
LIST_RELATIONS_MACRO_NAME = "list_relations_without_caching"


class MariaDBAdapter(SQLAdapter):
Expand All @@ -29,28 +29,23 @@ class MariaDBAdapter(SQLAdapter):

@classmethod
def date_function(cls):
return 'current_date()'
return "current_date()"

@classmethod
def convert_datetime_type(
cls, agate_table: agate.Table, col_idx: int
) -> str:
def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
return "timestamp"

def quote(self, identifier):
return '`{}`'.format(identifier)
return "`{}`".format(identifier)

def list_relations_without_caching(
self, schema_relation: MariaDBRelation
) -> List[MariaDBRelation]:
kwargs = {'schema_relation': schema_relation}
kwargs = {"schema_relation": schema_relation}
try:
results = self.execute_macro(
LIST_RELATIONS_MACRO_NAME,
kwargs=kwargs
)
results = self.execute_macro(LIST_RELATIONS_MACRO_NAME, kwargs=kwargs)
except dbt.exceptions.RuntimeException as e:
errmsg = getattr(e, 'msg', '')
errmsg = getattr(e, "msg", "")
if f"MariaDB database '{schema_relation}' not found" in errmsg:
return []
else:
Expand All @@ -64,13 +59,11 @@ def list_relations_without_caching(
raise dbt.exceptions.RuntimeException(
"Invalid value from "
f'"mariadb__list_relations_without_caching({kwargs})", '
f'got {len(row)} values, expected 4'
f"got {len(row)} values, expected 4"
)
_, name, _schema, relation_type = row
relation = self.Relation.create(
schema=_schema,
identifier=name,
type=relation_type
schema=_schema, identifier=name, type=relation_type
)
relations.append(relation)

Expand All @@ -88,9 +81,9 @@ def _get_columns_for_catalog(
for column in columns:
# convert MariaDBColumns into catalog dicts
as_dict = asdict(column)
as_dict['column_name'] = as_dict.pop('column', None)
as_dict['column_type'] = as_dict.pop('dtype')
as_dict['table_database'] = None
as_dict["column_name"] = as_dict.pop("column", None)
as_dict["column_type"] = as_dict.pop("dtype")
as_dict["table_database"] = None
yield as_dict

def get_relation(
Expand All @@ -102,48 +95,58 @@ def get_relation(
return super().get_relation(database, schema, identifier)

def parse_show_columns(
self,
relation: Relation,
raw_rows: List[agate.Row]
self, relation: Relation, raw_rows: List[agate.Row]
) -> List[MariaDBColumn]:
return [MariaDBColumn(
table_database=None,
table_schema=relation.schema,
table_name=relation.name,
table_type=relation.type,
table_owner=None,
table_stats=None,
column=column.column,
column_index=idx,
dtype=column.dtype,
) for idx, column in enumerate(raw_rows)]
return [
MariaDBColumn(
table_database=None,
table_schema=relation.schema,
table_name=relation.name,
table_type=relation.type,
table_owner=None,
table_stats=None,
column=column.column,
column_index=idx,
dtype=column.dtype,
)
for idx, column in enumerate(raw_rows)
]

def get_catalog(self, manifest):
schema_map = self._get_catalog_schemas(manifest)
if len(schema_map) > 1:
dbt.exceptions.raise_compiler_error(
f'Expected only one database in get_catalog, found '
f'{list(schema_map)}'
f"Expected only one database in get_catalog, found "
f"{list(schema_map)}"
)

with executor(self.config) as tpe:
futures: List[Future[agate.Table]] = []
for info, schemas in schema_map.items():
for schema in schemas:
futures.append(tpe.submit_connected(
self, schema,
self._get_one_catalog, info, [schema], manifest
))
futures.append(
tpe.submit_connected(
self,
schema,
self._get_one_catalog,
info,
[schema],
manifest,
)
)
catalogs, exceptions = catch_as_completed(futures)
return catalogs, exceptions

def _get_one_catalog(
self, information_schema, schemas, manifest,
self,
information_schema,
schemas,
manifest,
) -> agate.Table:
if len(schemas) != 1:
dbt.exceptions.raise_compiler_error(
f'Expected only one schema in mariadb _get_one_catalog, found '
f'{schemas}'
f"Expected only one schema in mariadb _get_one_catalog, found "
f"{schemas}"
)

database = information_schema.database
Expand All @@ -153,14 +156,11 @@ def _get_one_catalog(
for relation in self.list_relations(database, schema):
logger.debug("Getting table schema for relation {}", relation)
columns.extend(self._get_columns_for_catalog(relation))
return agate.Table.from_object(
columns, column_types=DEFAULT_TYPE_TESTER
)
return agate.Table.from_object(columns, column_types=DEFAULT_TYPE_TESTER)

def check_schema_exists(self, database, schema):
results = self.execute_macro(
LIST_SCHEMAS_MACRO_NAME,
kwargs={'database': database}
LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database}
)

exists = True if schema in [row[0] for row in results] else False
Expand All @@ -174,13 +174,13 @@ def update_column_sql(
clause: str,
where_clause: Optional[str] = None,
) -> str:
clause = f'update {dst_name} set {dst_column} = {clause}'
clause = f"update {dst_name} set {dst_column} = {clause}"
if where_clause is not None:
clause += f' where {where_clause}'
clause += f" where {where_clause}"
return clause

def timestamp_add_sql(
self, add_to: str, number: int = 1, interval: str = 'hour'
self, add_to: str, number: int = 1, interval: str = "hour"
) -> str:
# for backwards compatibility, we're compelled to set some sort of
# default. A lot of searching has lead me to believe that the
Expand All @@ -189,11 +189,14 @@ def timestamp_add_sql(
return f"date_add({add_to}, interval {number} {interval})"

def string_add_sql(
self, add_to: str, value: str, location='append',
self,
add_to: str,
value: str,
location="append",
) -> str:
if location == 'append':
if location == "append":
return f"concat({add_to}, '{value}')"
elif location == 'prepend':
elif location == "prepend":
return f"concat({value}, '{add_to}')"
else:
raise dbt.exceptions.RuntimeException(
Expand All @@ -216,15 +219,15 @@ def get_rows_different_sql(

alias_a = "A"
alias_b = "B"
columns_csv_a = ', '.join([f"{alias_a}.{name}" for name in names])
columns_csv_b = ', '.join([f"{alias_b}.{name}" for name in names])
columns_csv_a = ", ".join([f"{alias_a}.{name}" for name in names])
columns_csv_b = ", ".join([f"{alias_b}.{name}" for name in names])
join_condition = " AND ".join(
[f"{alias_a}.{name} = {alias_b}.{name}" for name in names]
)
first_column = names[0]

# There is no EXCEPT or MINUS operator, so we need to simulate it
COLUMNS_EQUAL_SQL = '''
COLUMNS_EQUAL_SQL = """
SELECT
row_count_diff.difference as row_count_difference,
diff_count.num_missing as num_mismatched
Expand Down Expand Up @@ -259,7 +262,7 @@ def get_rows_different_sql(
) as missing
) as diff_count ON row_count_diff.id = diff_count.id
'''.strip()
""".strip()

sql = COLUMNS_EQUAL_SQL.format(
alias_a=alias_a,
Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/mariadb/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class MariaDBIncludePolicy(Policy):
class MariaDBRelation(BaseRelation):
quote_policy: MariaDBQuotePolicy = MariaDBQuotePolicy()
include_policy: MariaDBIncludePolicy = MariaDBIncludePolicy()
quote_character: str = '`'
quote_character: str = "`"

def __post_init__(self):
if self.database != self.schema and self.database:
Expand Down
5 changes: 2 additions & 3 deletions dbt/adapters/mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,5 @@


Plugin = AdapterPlugin(
adapter=MySQLAdapter,
credentials=MySQLCredentials,
include_path=mysql.PACKAGE_PATH)
adapter=MySQLAdapter, credentials=MySQLCredentials, include_path=mysql.PACKAGE_PATH
)
Loading

0 comments on commit 509440e

Please sign in to comment.