Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sqlmesh/core/model/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,7 @@ def parse_strings_with_macro_refs(value: t.Any, dialect: DialectType) -> t.Any:

class ParsableSql(PydanticModel):
sql: str
transaction: t.Optional[bool] = None

_parsed: t.Optional[exp.Expression] = None
_parsed_dialect: t.Optional[str] = None
Expand Down
29 changes: 25 additions & 4 deletions sqlmesh/core/model/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ def render_pre_statements(
expand: t.Iterable[str] = tuple(),
deployability_index: t.Optional[DeployabilityIndex] = None,
engine_adapter: t.Optional[EngineAdapter] = None,
inside_transaction: t.Optional[bool] = True,
**kwargs: t.Any,
) -> t.List[exp.Expression]:
"""Renders pre-statements for a model.
Expand All @@ -384,7 +385,11 @@ def render_pre_statements(
The list of rendered expressions.
"""
return self._render_statements(
self.pre_statements,
[
stmt
for stmt in self.pre_statements
if stmt.args.get("transaction", True) == inside_transaction
],
start=start,
end=end,
execution_time=execution_time,
Expand All @@ -405,6 +410,7 @@ def render_post_statements(
expand: t.Iterable[str] = tuple(),
deployability_index: t.Optional[DeployabilityIndex] = None,
engine_adapter: t.Optional[EngineAdapter] = None,
inside_transaction: t.Optional[bool] = True,
**kwargs: t.Any,
) -> t.List[exp.Expression]:
"""Renders post-statements for a model.
Expand All @@ -420,13 +426,18 @@ def render_post_statements(
that depend on materialized tables. Model definitions are inlined and can thus be run end to
end on the fly.
deployability_index: Determines snapshots that are deployable in the context of this render.
inside_transaction: Whether to render hooks with transaction=True (inside) or transaction=False (outside).
kwargs: Additional kwargs to pass to the renderer.

Returns:
The list of rendered expressions.
"""
return self._render_statements(
self.post_statements,
[
stmt
for stmt in self.post_statements
if stmt.args.get("transaction", True) == inside_transaction
],
start=start,
end=end,
execution_time=execution_time,
Expand Down Expand Up @@ -567,6 +578,8 @@ def _get_parsed_statements(self, attr_name: str) -> t.List[exp.Expression]:
result = []
for v in value:
parsed = v.parse(self.dialect)
if getattr(v, "transaction", None) is not None:
parsed.set("transaction", v.transaction)
if not isinstance(parsed, exp.Semicolon):
result.append(parsed)
return result
Expand Down Expand Up @@ -2592,9 +2605,17 @@ def _create_model(
if statement_field in kwargs:
# Macros extracted from these statements need to be treated as metadata only
is_metadata = statement_field == "on_virtual_update"
statements.extend((stmt, is_metadata) for stmt in kwargs[statement_field])
for stmt in kwargs[statement_field]:
# Extract the parsed expression if it is already a ParsableSql
expr = stmt.parse(dialect) if isinstance(stmt, ParsableSql) else stmt
statements.append((expr, is_metadata))
kwargs[statement_field] = [
ParsableSql.from_parsed_expression(stmt, dialect, use_meta_sql=use_original_sql)
# this is to retain the transaction information
stmt
if isinstance(stmt, ParsableSql)
else ParsableSql.from_parsed_expression(
stmt, dialect, use_meta_sql=use_original_sql
)
for stmt in kwargs[statement_field]
]

Expand Down
67 changes: 54 additions & 13 deletions sqlmesh/core/snapshot/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,13 +750,19 @@ def _evaluate_snapshot(
**render_statements_kwargs
)

evaluation_strategy = _evaluation_strategy(snapshot, adapter)
evaluation_strategy.run_pre_statements(
snapshot=snapshot,
render_kwargs={**render_statements_kwargs, "inside_transaction": False},
)

with (
adapter.transaction(),
adapter.session(snapshot.model.render_session_properties(**render_statements_kwargs)),
):
evaluation_strategy = _evaluation_strategy(snapshot, adapter)
evaluation_strategy.run_pre_statements(
snapshot=snapshot, render_kwargs=render_statements_kwargs
snapshot=snapshot,
render_kwargs={**render_statements_kwargs, "inside_transaction": True},
)

if not target_table_exists or (model.is_seed and not snapshot.intervals):
Expand Down Expand Up @@ -828,10 +834,16 @@ def _evaluate_snapshot(
)

evaluation_strategy.run_post_statements(
snapshot=snapshot, render_kwargs=render_statements_kwargs
snapshot=snapshot,
render_kwargs={**render_statements_kwargs, "inside_transaction": True},
)

return wap_id
evaluation_strategy.run_post_statements(
snapshot=snapshot,
render_kwargs={**render_statements_kwargs, "inside_transaction": False},
)

return wap_id

def create_snapshot(
self,
Expand Down Expand Up @@ -865,6 +877,11 @@ def create_snapshot(
deployability_index=deployability_index,
)

evaluation_strategy = _evaluation_strategy(snapshot, adapter)
evaluation_strategy.run_pre_statements(
snapshot=snapshot, render_kwargs={**create_render_kwargs, "inside_transaction": False}
)

with (
adapter.transaction(),
adapter.session(snapshot.model.render_session_properties(**create_render_kwargs)),
Expand Down Expand Up @@ -896,6 +913,10 @@ def create_snapshot(
dry_run=True,
)

evaluation_strategy.run_post_statements(
snapshot=snapshot, render_kwargs={**create_render_kwargs, "inside_transaction": False}
)

if on_complete is not None:
on_complete(snapshot)

Expand Down Expand Up @@ -1097,6 +1118,11 @@ def _migrate_snapshot(
)
target_table_name = snapshot.table_name()

evaluation_strategy = _evaluation_strategy(snapshot, adapter)
evaluation_strategy.run_pre_statements(
snapshot=snapshot, render_kwargs={**render_kwargs, "inside_transaction": False}
)

with (
adapter.transaction(),
adapter.session(snapshot.model.render_session_properties(**render_kwargs)),
Expand Down Expand Up @@ -1134,6 +1160,10 @@ def _migrate_snapshot(
dry_run=True,
)

evaluation_strategy.run_post_statements(
snapshot=snapshot, render_kwargs={**render_kwargs, "inside_transaction": False}
)

# Retry in case when the table is migrated concurrently from another plan application
@retry(
reraise=True,
Expand Down Expand Up @@ -1454,7 +1484,8 @@ def _execute_create(
}
if run_pre_post_statements:
evaluation_strategy.run_pre_statements(
snapshot=snapshot, render_kwargs=create_render_kwargs
snapshot=snapshot,
render_kwargs={**create_render_kwargs, "inside_transaction": True},
)
evaluation_strategy.create(
table_name=table_name,
Expand All @@ -1471,7 +1502,8 @@ def _execute_create(
)
if run_pre_post_statements:
evaluation_strategy.run_post_statements(
snapshot=snapshot, render_kwargs=create_render_kwargs
snapshot=snapshot,
render_kwargs={**create_render_kwargs, "inside_transaction": True},
)

def _can_clone(self, snapshot: Snapshot, deployability_index: DeployabilityIndex) -> bool:
Expand Down Expand Up @@ -2944,12 +2976,20 @@ def append(
)

def run_pre_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None:
# in dbt custom materialisations it's up to the user when to run the pre hooks
pass
# in dbt custom materialisations it's up to the user to run the pre hooks inside the transaction
if not render_kwargs.get("inside_transaction", True):
super().run_pre_statements(
snapshot=snapshot,
render_kwargs=render_kwargs,
)

def run_post_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None:
# in dbt custom materialisations it's up to the user when to run the post hooks
pass
# in dbt custom materialisations it's up to the user to run the post hooks inside the transaction
if not render_kwargs.get("inside_transaction", True):
super().run_post_statements(
snapshot=snapshot,
render_kwargs=render_kwargs,
)

def _execute_materialization(
self,
Expand Down Expand Up @@ -2985,14 +3025,15 @@ def _execute_materialization(
"sql": str(query_or_df),
"is_first_insert": is_first_insert,
"create_only": create_only,
# FIXME: Add support for transaction=False
"pre_hooks": [
AttributeDict({"sql": s.this.this, "transaction": True})
AttributeDict({"sql": s.this.this, "transaction": transaction})
for s in model.pre_statements
if (transaction := s.args.get("transaction", True))
],
"post_hooks": [
AttributeDict({"sql": s.this.this, "transaction": True})
AttributeDict({"sql": s.this.this, "transaction": transaction})
for s in model.post_statements
if (transaction := s.args.get("transaction", True))
],
"model_instance": model,
**kwargs,
Expand Down
13 changes: 10 additions & 3 deletions sqlmesh/dbt/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sqlmesh.core.config.base import UpdateStrategy
from sqlmesh.core.config.common import VirtualEnvironmentMode
from sqlmesh.core.model import Model
from sqlmesh.core.model.common import ParsableSql
from sqlmesh.core.node import DbtNodeInfo
from sqlmesh.dbt.column import (
ColumnConfig,
Expand Down Expand Up @@ -87,7 +88,7 @@ class Hook(DbtConfig):
"""

sql: SqlStr
transaction: bool = True # TODO not yet supported
transaction: bool = True

_sql_validator = sql_str_validator

Expand Down Expand Up @@ -339,8 +340,14 @@ def sqlmesh_model_kwargs(
),
"jinja_macros": jinja_macros,
"path": self.path,
"pre_statements": [d.jinja_statement(hook.sql) for hook in self.pre_hook],
"post_statements": [d.jinja_statement(hook.sql) for hook in self.post_hook],
"pre_statements": [
ParsableSql(sql=d.jinja_statement(hook.sql).sql(), transaction=hook.transaction)
for hook in self.pre_hook
],
"post_statements": [
ParsableSql(sql=d.jinja_statement(hook.sql).sql(), transaction=hook.transaction)
for hook in self.post_hook
],
"tags": self.tags,
"physical_schema_mapping": context.sqlmesh_config.physical_schema_mapping,
"default_catalog": context.target.database,
Expand Down
20 changes: 10 additions & 10 deletions tests/core/test_snapshot_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3232,11 +3232,11 @@ def test_create_post_statements_use_non_deployable_table(
evaluator.create([snapshot], {}, DeployabilityIndex.none_deployable())

call_args = adapter_mock.execute.call_args_list
pre_calls = call_args[0][0][0]
pre_calls = call_args[1][0][0]
assert len(pre_calls) == 1
assert pre_calls[0].sql(dialect="postgres") == expected_call

post_calls = call_args[1][0][0]
post_calls = call_args[2][0][0]
assert len(post_calls) == 1
assert post_calls[0].sql(dialect="postgres") == expected_call

Expand Down Expand Up @@ -3294,11 +3294,11 @@ def model_with_statements(context, **kwargs):
expected_call = f'CREATE INDEX IF NOT EXISTS "idx" ON "sqlmesh__db"."db__test_model__{snapshot.version}__dev" /* db.test_model */("id")'

call_args = adapter_mock.execute.call_args_list
pre_calls = call_args[0][0][0]
pre_calls = call_args[1][0][0]
assert len(pre_calls) == 1
assert pre_calls[0].sql(dialect="postgres") == expected_call

post_calls = call_args[1][0][0]
post_calls = call_args[2][0][0]
assert len(post_calls) == 1
assert post_calls[0].sql(dialect="postgres") == expected_call

Expand Down Expand Up @@ -3356,14 +3356,14 @@ def create_log_table(evaluator, view_name):
)

call_args = adapter_mock.execute.call_args_list
post_calls = call_args[1][0][0]
post_calls = call_args[2][0][0]
assert len(post_calls) == 1
assert (
post_calls[0].sql(dialect="postgres")
== f'CREATE INDEX IF NOT EXISTS "test_idx" ON "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}__dev" /* test_schema.test_model */("a")'
)

on_virtual_update_calls = call_args[2][0][0]
on_virtual_update_calls = call_args[4][0][0]
assert (
on_virtual_update_calls[0].sql(dialect="postgres")
== 'GRANT SELECT ON VIEW "test_schema__test_env"."test_model" /* test_schema.test_model */ TO ROLE "admin"'
Expand Down Expand Up @@ -3441,7 +3441,7 @@ def model_with_statements(context, **kwargs):
)

call_args = adapter_mock.execute.call_args_list
on_virtual_update_call = call_args[2][0][0][0]
on_virtual_update_call = call_args[4][0][0][0]
assert (
on_virtual_update_call.sql(dialect="postgres")
== 'CREATE INDEX IF NOT EXISTS "idx" ON "db"."test_model_3" /* db.test_model_3 */("id")'
Expand Down Expand Up @@ -4187,11 +4187,11 @@ def test_multiple_engine_creation(snapshot: Snapshot, adapters, make_snapshot):
assert view_args[1][0][0] == "test_schema__test_env.test_model"

call_args = engine_adapters["secondary"].execute.call_args_list
pre_calls = call_args[0][0][0]
pre_calls = call_args[1][0][0]
assert len(pre_calls) == 1
assert pre_calls[0].sql(dialect="postgres") == expected_call

post_calls = call_args[1][0][0]
post_calls = call_args[2][0][0]
assert len(post_calls) == 1
assert post_calls[0].sql(dialect="postgres") == expected_call

Expand Down Expand Up @@ -4459,7 +4459,7 @@ def model_with_statements(context, **kwargs):

# For the pre/post statements verify the model-specific gateway was used
engine_adapters["default"].execute.assert_called_once()
assert len(engine_adapters["secondary"].execute.call_args_list) == 2
assert len(engine_adapters["secondary"].execute.call_args_list) == 4

# Validate that the get_catalog_type method was called only on the secondary engine from the macro evaluator
engine_adapters["default"].get_catalog_type.assert_not_called()
Expand Down
Loading