Skip to content

Commit

Permalink
Upgraded to dbt-core 1.4. (#146)
Browse files Browse the repository at this point in the history
* Upgraded to dbt-core 1.4.

* Updated CHANGELOG.

* Fixed policy fields definitions for mariadb and mysql5.

* Replaced deprecated dbt.exceptions.raise_compiler_error() with dbt.exceptions.CompilationError.

* Now using dbt.exceptions.DbtDatabaseError insead of dbt.exceptions.DatabaseException.

* Update version

* Update changelog

---------

Co-authored-by: Doug Beatty <[email protected]>
  • Loading branch information
lpezet and dbeatty10 authored Jun 11, 2023
1 parent 509440e commit 3db05eb
Show file tree
Hide file tree
Showing 14 changed files with 110 additions and 75 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
## Unreleased (TBD)

### Features
- Support dbt v1.4 ([#146](https://github.com/dbeatty10/dbt-mysql/pull/146))

### Contributors
- [@lpezet](https://github.com/lpezet) ([#146](https://github.com/dbeatty10/dbt-mysql/pull/146))

## dbt-mysql 1.1.0 (Feb 5, 2023)

### Features
- Support dbt v1.1 ([#100](https://github.com/dbeatty10/dbt-mysql/pull/100))
- More clear exception for invalid `database` config ([#110](https://github.com/dbeatty10/dbt-mysql/issues/110), [#111](https://github.com/dbeatty10/dbt-mysql/pull/111))
Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/mariadb/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version = "1.2.0a1"
version = "1.4.0a1"
21 changes: 12 additions & 9 deletions dbt/adapters/mariadb/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ def __init__(self, **kwargs):
def __post_init__(self):
# Database and schema are treated as the same thing
if self.database is not None and self.database != self.schema:
raise dbt.exceptions.RuntimeException(
raise dbt.exceptions.DbtRuntimeError(
f" schema: {self.schema} \n"
f" database: {self.database} \n"
f"On MariaDB, database must be omitted or have the same value as"
f" schema."
f"On MariaDB, database must be omitted"
f" or have the same value as schema."
)

@property
Expand Down Expand Up @@ -117,7 +117,7 @@ def open(cls, connection):
connection.handle = None
connection.state = "fail"

raise dbt.exceptions.FailedToConnectException(str(e))
raise dbt.exceptions.FailedToConnectError(str(e))

return connection

Expand All @@ -142,19 +142,19 @@ def exception_handler(self, sql):
logger.debug("Failed to release connection!")
pass

raise dbt.exceptions.DatabaseException(str(e).strip()) from e
raise dbt.exceptions.DbtDatabaseError(str(e).strip()) from e

except Exception as e:
logger.debug("Error running SQL: {}", sql)
logger.debug("Rolling back transaction.")
self.rollback_if_open()
if isinstance(e, dbt.exceptions.RuntimeException):
if isinstance(e, dbt.exceptions.DbtRuntimeError):
# during a sql query, an internal to dbt exception was raised.
# this sounds a lot like a signal handler and probably has
# useful information, so raise it without modification.
raise

raise dbt.exceptions.RuntimeException(e) from e
raise dbt.exceptions.DbtRuntimeError(e) from e

@classmethod
def get_response(cls, cursor) -> AdapterResponse:
Expand All @@ -164,8 +164,11 @@ def get_response(cls, cursor) -> AdapterResponse:
if cursor is not None and cursor.rowcount is not None:
num_rows = cursor.rowcount

# There's no real way to get the status from the mysql-connector-python driver.
# There's no real way to get the status from
# the mysql-connector-python driver.
# So just return the default value.
return AdapterResponse(
_message="{} {}".format(code, num_rows), rows_affected=num_rows, code=code
_message="{} {}".format(code, num_rows),
rows_affected=num_rows,
code=code
)
24 changes: 14 additions & 10 deletions dbt/adapters/mariadb/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def date_function(cls):
return "current_date()"

@classmethod
def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_datetime_type(cls, agate_table: agate.Table,
col_idx: int) -> str:
return "timestamp"

def quote(self, identifier):
Expand All @@ -43,8 +44,9 @@ def list_relations_without_caching(
) -> List[MariaDBRelation]:
kwargs = {"schema_relation": schema_relation}
try:
results = self.execute_macro(LIST_RELATIONS_MACRO_NAME, kwargs=kwargs)
except dbt.exceptions.RuntimeException as e:
results = self.execute_macro(LIST_RELATIONS_MACRO_NAME,
kwargs=kwargs)
except dbt.exceptions.DbtRuntimeError as e:
errmsg = getattr(e, "msg", "")
if f"MariaDB database '{schema_relation}' not found" in errmsg:
return []
Expand All @@ -56,7 +58,7 @@ def list_relations_without_caching(
relations = []
for row in results:
if len(row) != 4:
raise dbt.exceptions.RuntimeException(
raise dbt.exceptions.DbtRuntimeError(
"Invalid value from "
f'"mariadb__list_relations_without_caching({kwargs})", '
f"got {len(row)} values, expected 4"
Expand All @@ -69,7 +71,8 @@ def list_relations_without_caching(

return relations

def get_columns_in_relation(self, relation: Relation) -> List[MariaDBColumn]:
def get_columns_in_relation(self,
relation: Relation) -> List[MariaDBColumn]:
rows: List[agate.Row] = super().get_columns_in_relation(relation)
return self.parse_show_columns(relation, rows)

Expand All @@ -89,7 +92,7 @@ def _get_columns_for_catalog(
def get_relation(
self, database: str, schema: str, identifier: str
) -> Optional[BaseRelation]:
if not self.Relation.include_policy.database:
if not self.Relation.get_default_include_policy().database:
database = None

return super().get_relation(database, schema, identifier)
Expand All @@ -115,7 +118,7 @@ def parse_show_columns(
def get_catalog(self, manifest):
schema_map = self._get_catalog_schemas(manifest)
if len(schema_map) > 1:
dbt.exceptions.raise_compiler_error(
raise dbt.exceptions.CompilationError(
f"Expected only one database in get_catalog, found "
f"{list(schema_map)}"
)
Expand Down Expand Up @@ -144,7 +147,7 @@ def _get_one_catalog(
manifest,
) -> agate.Table:
if len(schemas) != 1:
dbt.exceptions.raise_compiler_error(
raise dbt.exceptions.CompilationError(
f"Expected only one schema in mariadb _get_one_catalog, found "
f"{schemas}"
)
Expand All @@ -156,7 +159,8 @@ def _get_one_catalog(
for relation in self.list_relations(database, schema):
logger.debug("Getting table schema for relation {}", relation)
columns.extend(self._get_columns_for_catalog(relation))
return agate.Table.from_object(columns, column_types=DEFAULT_TYPE_TESTER)
return agate.Table.from_object(columns,
column_types=DEFAULT_TYPE_TESTER)

def check_schema_exists(self, database, schema):
results = self.execute_macro(
Expand Down Expand Up @@ -199,7 +203,7 @@ def string_add_sql(
elif location == "prepend":
return f"concat({value}, '{add_to}')"
else:
raise dbt.exceptions.RuntimeException(
raise dbt.exceptions.DbtRuntimeError(
f'Got an unexpected location value of "{location}"'
)

Expand Down
14 changes: 8 additions & 6 deletions dbt/adapters/mariadb/relation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from dataclasses import dataclass, field

from dbt.adapters.base.relation import BaseRelation, Policy
from dbt.exceptions import RuntimeException
from dbt.exceptions import DbtRuntimeError


@dataclass
Expand All @@ -20,21 +20,23 @@ class MariaDBIncludePolicy(Policy):

@dataclass(frozen=True, eq=False, repr=False)
class MariaDBRelation(BaseRelation):
quote_policy: MariaDBQuotePolicy = MariaDBQuotePolicy()
include_policy: MariaDBIncludePolicy = MariaDBIncludePolicy()
quote_policy: MariaDBQuotePolicy = field(
default_factory=lambda: MariaDBQuotePolicy())
include_policy: MariaDBIncludePolicy = field(
default_factory=lambda: MariaDBIncludePolicy())
quote_character: str = "`"

def __post_init__(self):
if self.database != self.schema and self.database:
raise RuntimeException(
raise DbtRuntimeError(
f"Cannot set `database` to '{self.database}' in MariaDB!"
"You can either unset `database`, or make it match `schema`, "
f"currently set to '{self.schema}'"
)

def render(self):
if self.include_policy.database and self.include_policy.schema:
raise RuntimeException(
raise DbtRuntimeError(
"Got a MariaDB relation with schema and database set to "
"include, but only one can be set"
)
Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/mysql/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version = "1.2.0a1"
version = "1.4.0a1"
17 changes: 10 additions & 7 deletions dbt/adapters/mysql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, **kwargs):
def __post_init__(self):
# mysql classifies database and schema as the same thing
if self.database is not None and self.database != self.schema:
raise dbt.exceptions.RuntimeException(
raise dbt.exceptions.DbtRuntimeError(
f" schema: {self.schema} \n"
f" database: {self.database} \n"
f"On MySQL, database must be omitted or have the same value as"
Expand Down Expand Up @@ -113,7 +113,7 @@ def open(cls, connection):
connection.handle = None
connection.state = "fail"

raise dbt.exceptions.FailedToConnectException(str(e))
raise dbt.exceptions.FailedToConnectError(str(e))

return connection

Expand All @@ -138,19 +138,19 @@ def exception_handler(self, sql):
logger.debug("Failed to release connection!")
pass

raise dbt.exceptions.DatabaseException(str(e).strip()) from e
raise dbt.exceptions.DbtDatabaseError(str(e).strip()) from e

except Exception as e:
logger.debug("Error running SQL: {}", sql)
logger.debug("Rolling back transaction.")
self.rollback_if_open()
if isinstance(e, dbt.exceptions.RuntimeException):
if isinstance(e, dbt.exceptions.DbtRuntimeError):
# during a sql query, an internal to dbt exception was raised.
# this sounds a lot like a signal handler and probably has
# useful information, so raise it without modification.
raise

raise dbt.exceptions.RuntimeException(e) from e
raise dbt.exceptions.DbtRuntimeError(e) from e

@classmethod
def get_response(cls, cursor) -> AdapterResponse:
Expand All @@ -160,8 +160,11 @@ def get_response(cls, cursor) -> AdapterResponse:
if cursor is not None and cursor.rowcount is not None:
num_rows = cursor.rowcount

# There's no real way to get the status from the mysql-connector-python driver.
# There's no real way to get the status from the
# mysql-connector-python driver.
# So just return the default value.
return AdapterResponse(
_message="{} {}".format(code, num_rows), rows_affected=num_rows, code=code
_message="{} {}".format(code, num_rows),
rows_affected=num_rows,
code=code
)
24 changes: 14 additions & 10 deletions dbt/adapters/mysql/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def date_function(cls):
return "current_date()"

@classmethod
def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_datetime_type(cls, agate_table: agate.Table,
col_idx: int) -> str:
return "timestamp"

def quote(self, identifier):
Expand All @@ -43,8 +44,9 @@ def list_relations_without_caching(
) -> List[MySQLRelation]:
kwargs = {"schema_relation": schema_relation}
try:
results = self.execute_macro(LIST_RELATIONS_MACRO_NAME, kwargs=kwargs)
except dbt.exceptions.RuntimeException as e:
results = self.execute_macro(LIST_RELATIONS_MACRO_NAME,
kwargs=kwargs)
except dbt.exceptions.DbtRuntimeError as e:
errmsg = getattr(e, "msg", "")
if f"MySQL database '{schema_relation}' not found" in errmsg:
return []
Expand All @@ -56,7 +58,7 @@ def list_relations_without_caching(
relations = []
for row in results:
if len(row) != 4:
raise dbt.exceptions.RuntimeException(
raise dbt.exceptions.DbtRuntimeError(
"Invalid value from "
f'"mysql__list_relations_without_caching({kwargs})", '
f"got {len(row)} values, expected 4"
Expand Down Expand Up @@ -89,7 +91,7 @@ def _get_columns_for_catalog(
def get_relation(
self, database: str, schema: str, identifier: str
) -> Optional[BaseRelation]:
if not self.Relation.include_policy.database:
if not self.Relation.get_default_include_policy().database:
database = None

return super().get_relation(database, schema, identifier)
Expand All @@ -116,7 +118,7 @@ def get_catalog(self, manifest):
schema_map = self._get_catalog_schemas(manifest)

if len(schema_map) > 1:
dbt.exceptions.raise_compiler_error(
raise dbt.exceptions.CompilationError(
f"Expected only one database in get_catalog, found "
f"{list(schema_map)}"
)
Expand Down Expand Up @@ -145,7 +147,7 @@ def _get_one_catalog(
manifest,
) -> agate.Table:
if len(schemas) != 1:
dbt.exceptions.raise_compiler_error(
raise dbt.exceptions.CompilationError(
f"Expected only one schema in mysql _get_one_catalog, found "
f"{schemas}"
)
Expand All @@ -157,7 +159,8 @@ def _get_one_catalog(
for relation in self.list_relations(database, schema):
logger.debug("Getting table schema for relation {}", relation)
columns.extend(self._get_columns_for_catalog(relation))
return agate.Table.from_object(columns, column_types=DEFAULT_TYPE_TESTER)
return agate.Table.from_object(columns,
column_types=DEFAULT_TYPE_TESTER)

def check_schema_exists(self, database, schema):
results = self.execute_macro(
Expand Down Expand Up @@ -200,7 +203,7 @@ def string_add_sql(
elif location == "prepend":
return f"concat({value}, '{add_to}')"
else:
raise dbt.exceptions.RuntimeException(
raise dbt.exceptions.DbtRuntimeError(
f'Got an unexpected location value of "{location}"'
)

Expand All @@ -227,7 +230,8 @@ def get_rows_different_sql(
)
first_column = names[0]

# MySQL doesn't have an EXCEPT or MINUS operator, so we need to simulate it
# MySQL doesn't have an EXCEPT or MINUS operator,
# so we need to simulate it
COLUMNS_EQUAL_SQL = """
WITH
a_except_b as (
Expand Down
14 changes: 8 additions & 6 deletions dbt/adapters/mysql/relation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from dataclasses import dataclass, field

from dbt.adapters.base.relation import BaseRelation, Policy
from dbt.exceptions import RuntimeException
from dbt.exceptions import DbtRuntimeError


@dataclass
Expand All @@ -20,21 +20,23 @@ class MySQLIncludePolicy(Policy):

@dataclass(frozen=True, eq=False, repr=False)
class MySQLRelation(BaseRelation):
quote_policy: MySQLQuotePolicy = MySQLQuotePolicy()
include_policy: MySQLIncludePolicy = MySQLIncludePolicy()
quote_policy: MySQLQuotePolicy = field(
default_factory=lambda: MySQLQuotePolicy())
include_policy: MySQLIncludePolicy = field(
default_factory=lambda: MySQLIncludePolicy())
quote_character: str = "`"

def __post_init__(self):
if self.database != self.schema and self.database:
raise RuntimeException(
raise DbtRuntimeError(
f"Cannot set `database` to '{self.database}' in mysql!"
"You can either unset `database`, or make it match `schema`, "
f"currently set to '{self.schema}'"
)

def render(self):
if self.include_policy.database and self.include_policy.schema:
raise RuntimeException(
raise DbtRuntimeError(
"Got a mysql relation with schema and database set to "
"include, but only one can be set"
)
Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/mysql5/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version = "1.2.0a1"
version = "1.4.0a1"
Loading

0 comments on commit 3db05eb

Please sign in to comment.