diff --git a/sqlmesh/core/model/common.py b/sqlmesh/core/model/common.py
index d2b9a11c08..9e117b56fb 100644
--- a/sqlmesh/core/model/common.py
+++ b/sqlmesh/core/model/common.py
@@ -663,6 +663,7 @@ def parse_strings_with_macro_refs(value: t.Any, dialect: DialectType) -> t.Any:
 
 class ParsableSql(PydanticModel):
     sql: str
+    transaction: t.Optional[bool] = None
 
     _parsed: t.Optional[exp.Expression] = None
     _parsed_dialect: t.Optional[str] = None
diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py
index f81dae004b..0a20ab23b2 100644
--- a/sqlmesh/core/model/definition.py
+++ b/sqlmesh/core/model/definition.py
@@ -363,6 +363,7 @@ def render_pre_statements(
         expand: t.Iterable[str] = tuple(),
         deployability_index: t.Optional[DeployabilityIndex] = None,
         engine_adapter: t.Optional[EngineAdapter] = None,
+        inside_transaction: t.Optional[bool] = True,
         **kwargs: t.Any,
     ) -> t.List[exp.Expression]:
         """Renders pre-statements for a model.
@@ -384,7 +385,11 @@
            The list of rendered expressions.
        """
        return self._render_statements(
-            self.pre_statements,
+            [
+                stmt
+                for stmt in self.pre_statements
+                if stmt.args.get("transaction", True) == inside_transaction
+            ],
            start=start,
            end=end,
            execution_time=execution_time,
@@ -405,6 +410,7 @@ def render_post_statements(
         expand: t.Iterable[str] = tuple(),
         deployability_index: t.Optional[DeployabilityIndex] = None,
         engine_adapter: t.Optional[EngineAdapter] = None,
+        inside_transaction: t.Optional[bool] = True,
         **kwargs: t.Any,
     ) -> t.List[exp.Expression]:
         """Renders post-statements for a model.
@@ -420,13 +426,18 @@
                that depend on materialized tables. Model definitions are inlined and can thus be run end to
                end on the fly.
            deployability_index: Determines snapshots that are deployable in the context of this render.
+            inside_transaction: Whether to render hooks with transaction=True (inside) or transaction=False (outside).
            kwargs: Additional kwargs to pass to the renderer.
 
        Returns:
            The list of rendered expressions.
""" return self._render_statements( - self.post_statements, + [ + stmt + for stmt in self.post_statements + if stmt.args.get("transaction", True) == inside_transaction + ], start=start, end=end, execution_time=execution_time, @@ -567,6 +578,8 @@ def _get_parsed_statements(self, attr_name: str) -> t.List[exp.Expression]: result = [] for v in value: parsed = v.parse(self.dialect) + if getattr(v, "transaction", None) is not None: + parsed.set("transaction", v.transaction) if not isinstance(parsed, exp.Semicolon): result.append(parsed) return result @@ -2592,9 +2605,17 @@ def _create_model( if statement_field in kwargs: # Macros extracted from these statements need to be treated as metadata only is_metadata = statement_field == "on_virtual_update" - statements.extend((stmt, is_metadata) for stmt in kwargs[statement_field]) + for stmt in kwargs[statement_field]: + # Extract the expression if it's ParsableSql already + expr = stmt.parse(dialect) if isinstance(stmt, ParsableSql) else stmt + statements.append((expr, is_metadata)) kwargs[statement_field] = [ - ParsableSql.from_parsed_expression(stmt, dialect, use_meta_sql=use_original_sql) + # this to retain the transaction information + stmt + if isinstance(stmt, ParsableSql) + else ParsableSql.from_parsed_expression( + stmt, dialect, use_meta_sql=use_original_sql + ) for stmt in kwargs[statement_field] ] diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 2676709d85..773010d673 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -750,13 +750,19 @@ def _evaluate_snapshot( **render_statements_kwargs ) + evaluation_strategy = _evaluation_strategy(snapshot, adapter) + evaluation_strategy.run_pre_statements( + snapshot=snapshot, + render_kwargs={**render_statements_kwargs, "inside_transaction": False}, + ) + with ( adapter.transaction(), adapter.session(snapshot.model.render_session_properties(**render_statements_kwargs)), ): - evaluation_strategy = _evaluation_strategy(snapshot, adapter) evaluation_strategy.run_pre_statements( - snapshot=snapshot, render_kwargs=render_statements_kwargs + snapshot=snapshot, + render_kwargs={**render_statements_kwargs, "inside_transaction": True}, ) if not target_table_exists or (model.is_seed and not snapshot.intervals): @@ -828,10 +834,16 @@ def _evaluate_snapshot( ) evaluation_strategy.run_post_statements( - snapshot=snapshot, render_kwargs=render_statements_kwargs + snapshot=snapshot, + render_kwargs={**render_statements_kwargs, "inside_transaction": True}, ) - return wap_id + evaluation_strategy.run_post_statements( + snapshot=snapshot, + render_kwargs={**render_statements_kwargs, "inside_transaction": False}, + ) + + return wap_id def create_snapshot( self, @@ -865,6 +877,11 @@ def create_snapshot( deployability_index=deployability_index, ) + evaluation_strategy = _evaluation_strategy(snapshot, adapter) + evaluation_strategy.run_pre_statements( + snapshot=snapshot, render_kwargs={**create_render_kwargs, "inside_transaction": False} + ) + with ( adapter.transaction(), adapter.session(snapshot.model.render_session_properties(**create_render_kwargs)), @@ -896,6 +913,10 @@ def create_snapshot( dry_run=True, ) + evaluation_strategy.run_post_statements( + snapshot=snapshot, render_kwargs={**create_render_kwargs, "inside_transaction": False} + ) + if on_complete is not None: on_complete(snapshot) @@ -1097,6 +1118,11 @@ def _migrate_snapshot( ) target_table_name = snapshot.table_name() + evaluation_strategy = _evaluation_strategy(snapshot, 
+        evaluation_strategy.run_pre_statements(
+            snapshot=snapshot, render_kwargs={**render_kwargs, "inside_transaction": False}
+        )
+
         with (
             adapter.transaction(),
             adapter.session(snapshot.model.render_session_properties(**render_kwargs)),
@@ -1134,6 +1160,10 @@
                 dry_run=True,
             )
 
+        evaluation_strategy.run_post_statements(
+            snapshot=snapshot, render_kwargs={**render_kwargs, "inside_transaction": False}
+        )
+
     # Retry in case when the table is migrated concurrently from another plan application
     @retry(
         reraise=True,
@@ -1454,7 +1484,8 @@ def _execute_create(
         }
         if run_pre_post_statements:
             evaluation_strategy.run_pre_statements(
-                snapshot=snapshot, render_kwargs=create_render_kwargs
+                snapshot=snapshot,
+                render_kwargs={**create_render_kwargs, "inside_transaction": True},
             )
         evaluation_strategy.create(
             table_name=table_name,
@@ -1471,7 +1502,8 @@
         )
         if run_pre_post_statements:
             evaluation_strategy.run_post_statements(
-                snapshot=snapshot, render_kwargs=create_render_kwargs
+                snapshot=snapshot,
+                render_kwargs={**create_render_kwargs, "inside_transaction": True},
             )
 
     def _can_clone(self, snapshot: Snapshot, deployability_index: DeployabilityIndex) -> bool:
@@ -2944,12 +2976,20 @@ def append(
         )
 
     def run_pre_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None:
-        # in dbt custom materialisations it's up to the user when to run the pre hooks
-        pass
+        # in dbt custom materialisations it's up to the user to run the in-transaction pre hooks, so only run the hooks marked transaction=False here
+        if not render_kwargs.get("inside_transaction", True):
+            super().run_pre_statements(
+                snapshot=snapshot,
+                render_kwargs=render_kwargs,
+            )
 
     def run_post_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None:
-        # in dbt custom materialisations it's up to the user when to run the post hooks
-        pass
+        # in dbt custom materialisations it's up to the user to run the in-transaction post hooks, so only run the hooks marked transaction=False here
+        if not render_kwargs.get("inside_transaction", True):
+            super().run_post_statements(
+                snapshot=snapshot,
+                render_kwargs=render_kwargs,
+            )
 
     def _execute_materialization(
         self,
@@ -2985,14 +3025,15 @@
                 "sql": str(query_or_df),
                 "is_first_insert": is_first_insert,
                 "create_only": create_only,
-                # FIXME: Add support for transaction=False
                 "pre_hooks": [
-                    AttributeDict({"sql": s.this.this, "transaction": True})
+                    AttributeDict({"sql": s.this.this, "transaction": transaction})
                     for s in model.pre_statements
+                    if (transaction := s.args.get("transaction", True))
                 ],
                 "post_hooks": [
-                    AttributeDict({"sql": s.this.this, "transaction": True})
+                    AttributeDict({"sql": s.this.this, "transaction": transaction})
                     for s in model.post_statements
+                    if (transaction := s.args.get("transaction", True))
                 ],
                 "model_instance": model,
                 **kwargs,
diff --git a/sqlmesh/dbt/basemodel.py b/sqlmesh/dbt/basemodel.py
index 7c7e9e2e76..0c719ebb88 100644
--- a/sqlmesh/dbt/basemodel.py
+++ b/sqlmesh/dbt/basemodel.py
@@ -13,6 +13,7 @@
 from sqlmesh.core.config.base import UpdateStrategy
 from sqlmesh.core.config.common import VirtualEnvironmentMode
 from sqlmesh.core.model import Model
+from sqlmesh.core.model.common import ParsableSql
 from sqlmesh.core.node import DbtNodeInfo
 from sqlmesh.dbt.column import (
     ColumnConfig,
@@ -87,7 +88,7 @@ class Hook(DbtConfig):
     """
 
     sql: SqlStr
-    transaction: bool = True  # TODO not yet supported
+    transaction: bool = True
 
     _sql_validator = sql_str_validator
 
@@ -339,8 +340,14 @@ def sqlmesh_model_kwargs(
             ),
             "jinja_macros": jinja_macros,
             "path": self.path,
-            "pre_statements": [d.jinja_statement(hook.sql) for hook in self.pre_hook],
-            "post_statements": [d.jinja_statement(hook.sql) for hook in self.post_hook],
+            "pre_statements": [
+                ParsableSql(sql=d.jinja_statement(hook.sql).sql(), transaction=hook.transaction)
+                for hook in self.pre_hook
+            ],
+            "post_statements": [
+                ParsableSql(sql=d.jinja_statement(hook.sql).sql(), transaction=hook.transaction)
+                for hook in self.post_hook
+            ],
             "tags": self.tags,
             "physical_schema_mapping": context.sqlmesh_config.physical_schema_mapping,
             "default_catalog": context.target.database,
diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py
index 68061544a8..c0a7a01b51 100644
--- a/tests/core/test_snapshot_evaluator.py
+++ b/tests/core/test_snapshot_evaluator.py
@@ -3232,11 +3232,11 @@ def test_create_post_statements_use_non_deployable_table(
     evaluator.create([snapshot], {}, DeployabilityIndex.none_deployable())
 
     call_args = adapter_mock.execute.call_args_list
-    pre_calls = call_args[0][0][0]
+    pre_calls = call_args[1][0][0]
     assert len(pre_calls) == 1
     assert pre_calls[0].sql(dialect="postgres") == expected_call
 
-    post_calls = call_args[1][0][0]
+    post_calls = call_args[2][0][0]
     assert len(post_calls) == 1
     assert post_calls[0].sql(dialect="postgres") == expected_call
 
@@ -3294,11 +3294,11 @@ def model_with_statements(context, **kwargs):
     expected_call = f'CREATE INDEX IF NOT EXISTS "idx" ON "sqlmesh__db"."db__test_model__{snapshot.version}__dev" /* db.test_model */("id")'
 
     call_args = adapter_mock.execute.call_args_list
-    pre_calls = call_args[0][0][0]
+    pre_calls = call_args[1][0][0]
     assert len(pre_calls) == 1
     assert pre_calls[0].sql(dialect="postgres") == expected_call
 
-    post_calls = call_args[1][0][0]
+    post_calls = call_args[2][0][0]
     assert len(post_calls) == 1
     assert post_calls[0].sql(dialect="postgres") == expected_call
 
@@ -3356,14 +3356,14 @@ def create_log_table(evaluator, view_name):
     )
 
     call_args = adapter_mock.execute.call_args_list
-    post_calls = call_args[1][0][0]
+    post_calls = call_args[2][0][0]
     assert len(post_calls) == 1
     assert (
         post_calls[0].sql(dialect="postgres")
         == f'CREATE INDEX IF NOT EXISTS "test_idx" ON "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}__dev" /* test_schema.test_model */("a")'
     )
 
-    on_virtual_update_calls = call_args[2][0][0]
+    on_virtual_update_calls = call_args[4][0][0]
     assert (
         on_virtual_update_calls[0].sql(dialect="postgres")
         == 'GRANT SELECT ON VIEW "test_schema__test_env"."test_model" /* test_schema.test_model */ TO ROLE "admin"'
@@ -3441,7 +3441,7 @@ def model_with_statements(context, **kwargs):
     )
 
     call_args = adapter_mock.execute.call_args_list
-    on_virtual_update_call = call_args[2][0][0][0]
+    on_virtual_update_call = call_args[4][0][0][0]
     assert (
         on_virtual_update_call.sql(dialect="postgres")
         == 'CREATE INDEX IF NOT EXISTS "idx" ON "db"."test_model_3" /* db.test_model_3 */("id")'
@@ -4187,11 +4187,11 @@ def test_multiple_engine_creation(snapshot: Snapshot, adapters, make_snapshot):
     assert view_args[1][0][0] == "test_schema__test_env.test_model"
 
     call_args = engine_adapters["secondary"].execute.call_args_list
-    pre_calls = call_args[0][0][0]
+    pre_calls = call_args[1][0][0]
    assert len(pre_calls) == 1
    assert pre_calls[0].sql(dialect="postgres") == expected_call
 
-    post_calls = call_args[1][0][0]
+    post_calls = call_args[2][0][0]
     assert len(post_calls) == 1
     assert post_calls[0].sql(dialect="postgres") == expected_call
 
@@ -4459,7 +4459,7 @@ def model_with_statements(context, **kwargs):
     # For the pre/post statements verify the model-specific gateway was used
     engine_adapters["default"].execute.assert_called_once()
-    assert len(engine_adapters["secondary"].execute.call_args_list) == 2
+    assert len(engine_adapters["secondary"].execute.call_args_list) == 4
 
     # Validate that the get_catalog_type method was called only on the secondary engine from the macro evaluator
     engine_adapters["default"].get_catalog_type.assert_not_called()
diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py
index 0a1091a7fc..dd69f46200 100644
--- a/tests/dbt/test_transformation.py
+++ b/tests/dbt/test_transformation.py
@@ -2707,3 +2707,180 @@ def test_ignore_source_depends_on_when_also_model(dbt_dummy_postgres_config: Pos
     }
 
     assert model.sqlmesh_model_kwargs(context)["depends_on"] == {"schema.source_b"}
+
+
+@pytest.mark.xdist_group("dbt_manifest")
+def test_dbt_hooks_with_transaction_flag(sushi_test_dbt_context: Context):
+    model_fqn = '"memory"."sushi"."model_with_transaction_hooks"'
+    assert model_fqn in sushi_test_dbt_context.models
+
+    model = sushi_test_dbt_context.models[model_fqn]
+
+    pre_statements = model.pre_statements_
+    assert pre_statements is not None
+    assert len(pre_statements) >= 3
+
+    # check the expected SQL and, more importantly, that the transaction flags are preserved
+    assert any(
+        s.sql == 'JINJA_STATEMENT_BEGIN;\n{{ log("pre-hook") }}\nJINJA_END;'
+        and s.transaction is True
+        for s in pre_statements
+    )
+    assert any(
+        "CREATE TABLE IF NOT EXISTS hook_outside_pre_table" in s.sql and s.transaction is False
+        for s in pre_statements
+    )
+    assert any(
+        "CREATE TABLE IF NOT EXISTS shared_hook_table" in s.sql and s.transaction is False
+        for s in pre_statements
+    )
+    assert any(
+        "{{ insert_into_shared_hook_table('inside_pre') }}" in s.sql and s.transaction is True
+        for s in pre_statements
+    )
+
+    post_statements = model.post_statements_
+    assert post_statements is not None
+    assert len(post_statements) >= 4
+    assert any(
+        s.sql == 'JINJA_STATEMENT_BEGIN;\n{{ log("post-hook") }}\nJINJA_END;'
+        and s.transaction is True
+        for s in post_statements
+    )
+    assert any(
+        "{{ insert_into_shared_hook_table('inside_post') }}" in s.sql and s.transaction is True
+        for s in post_statements
+    )
+    assert any(
+        "CREATE TABLE IF NOT EXISTS hook_outside_post_table" in s.sql and s.transaction is False
+        for s in post_statements
+    )
+    assert any(
+        "{{ insert_into_shared_hook_table('after_commit') }}" in s.sql and s.transaction is False
+        for s in post_statements
+    )
+
+    # render_pre_statements with inside_transaction=True should return only the insert statement
+    inside_pre_statements = model.render_pre_statements(inside_transaction=True)
+    assert len(inside_pre_statements) == 1
+    assert (
+        inside_pre_statements[0].sql()
+        == """INSERT INTO "shared_hook_table" ("id", "hook_name", "execution_order", "created_at") VALUES ((SELECT COALESCE(MAX("id"), 0) + 1 FROM "shared_hook_table"), 'inside_pre', (SELECT COALESCE(MAX("id"), 0) + 1 FROM "shared_hook_table"), NOW())"""
+    )
+
+    # while render_pre_statements with inside_transaction=False should return only the CREATE statements
+    outside_pre_statements = model.render_pre_statements(inside_transaction=False)
+    assert len(outside_pre_statements) == 2
+    assert "CREATE" in outside_pre_statements[0].sql()
+    assert "hook_outside_pre_table" in outside_pre_statements[0].sql()
+    assert "CREATE" in outside_pre_statements[1].sql()
+    assert "shared_hook_table" in outside_pre_statements[1].sql()
+
+    # similarly for post statements
+    inside_post_statements = model.render_post_statements(inside_transaction=True)
+    assert len(inside_post_statements) == 1
+    assert (
+        inside_post_statements[0].sql()
+        == """INSERT INTO "shared_hook_table" ("id", "hook_name", "execution_order", "created_at") VALUES ((SELECT COALESCE(MAX("id"), 0) + 1 FROM "shared_hook_table"), 'inside_post', (SELECT COALESCE(MAX("id"), 0) + 1 FROM "shared_hook_table"), NOW())"""
+    )
+
+    outside_post_statements = model.render_post_statements(inside_transaction=False)
+    assert len(outside_post_statements) == 2
+    assert "CREATE" in outside_post_statements[0].sql()
+    assert "hook_outside_post_table" in outside_post_statements[0].sql()
+    assert "INSERT" in outside_post_statements[1].sql()
+    assert "shared_hook_table" in outside_post_statements[1].sql()
+
+
+@pytest.mark.xdist_group("dbt_manifest")
+def test_dbt_hooks_with_transaction_flag_execution(sushi_test_dbt_context: Context):
+    model_fqn = '"memory"."sushi"."model_with_transaction_hooks"'
+    assert model_fqn in sushi_test_dbt_context.models
+
+    plan = sushi_test_dbt_context.plan(select_models=["sushi.model_with_transaction_hooks"])
+    sushi_test_dbt_context.apply(plan)
+
+    result = sushi_test_dbt_context.engine_adapter.fetchdf(
+        "SELECT * FROM sushi.model_with_transaction_hooks"
+    )
+    assert len(result) == 1
+    assert result["id"][0] == 1
+    assert result["name"][0] == "test"
+
+    # ensure the outside pre-hook and post-hook tables were created
+    pre_outside = sushi_test_dbt_context.engine_adapter.fetchdf(
+        "SELECT * FROM hook_outside_pre_table"
+    )
+    assert len(pre_outside) == 1
+    assert pre_outside["id"][0] == 1
+    assert pre_outside["location"][0] == "outside"
+    assert pre_outside["execution_order"][0] == 1
+
+    post_outside = sushi_test_dbt_context.engine_adapter.fetchdf(
+        "SELECT * FROM hook_outside_post_table"
+    )
+    assert len(post_outside) == 1
+    assert post_outside["id"][0] == 5
+    assert post_outside["location"][0] == "outside"
+    assert post_outside["execution_order"][0] == 5
+
+    # verify the shared table that was created by before_begin and populated by all hooks
+    shared_table = sushi_test_dbt_context.engine_adapter.fetchdf(
+        "SELECT * FROM shared_hook_table ORDER BY execution_order"
+    )
+    assert len(shared_table) == 3
+    assert shared_table["execution_order"].is_monotonic_increasing
+
+    # The order of creation and insertion verifies the following order of execution:
+    # 1. before_begin (transaction=false) ran BEFORE the transaction started and created the table
+    # 2. inside_pre (transaction=true) ran INSIDE the transaction and could insert into the table
+    # 3. inside_post (transaction=true) ran INSIDE the transaction and could insert into the table (but after the pre statement)
+    # 4. after_commit (transaction=false) ran AFTER the transaction committed
+
+    assert shared_table["id"][0] == 1
+    assert shared_table["hook_name"][0] == "inside_pre"
+    assert shared_table["execution_order"][0] == 1
+
+    assert shared_table["id"][1] == 2
+    assert shared_table["hook_name"][1] == "inside_post"
+    assert shared_table["execution_order"][1] == 2
+
+    assert shared_table["id"][2] == 3
+    assert shared_table["hook_name"][2] == "after_commit"
+    assert shared_table["execution_order"][2] == 3
+
+    # the timestamps should also be monotonically increasing for the same reason
+    for i in range(len(shared_table) - 1):
+        assert shared_table["created_at"][i] <= shared_table["created_at"][i + 1]
+
+    # the tables created via the alternate hook syntax should reflect the correct order as well
+    assert pre_outside["created_at"][0] < shared_table["created_at"][0]
+    assert post_outside["created_at"][0] > shared_table["created_at"][1]
+
+    # run with an execution time one day in the future to simulate a scheduled run
+    tomorrow = datetime.now() + timedelta(days=1)
+    sushi_test_dbt_context.run(
+        select_models=["sushi.model_with_transaction_hooks"], execution_time=tomorrow
+    )
+
+    # verify that the transaction information persists in state and is respected
+    shared_table = sushi_test_dbt_context.engine_adapter.fetchdf(
+        "SELECT * FROM shared_hook_table ORDER BY execution_order"
+    )
+
+    # the execution order for the run matches the plan application
+    assert shared_table["execution_order"].is_monotonic_increasing
+    assert shared_table["id"][3] == 4
+    assert shared_table["hook_name"][3] == "inside_pre"
+    assert shared_table["execution_order"][3] == 4
+
+    assert shared_table["id"][4] == 5
+    assert shared_table["hook_name"][4] == "inside_post"
+    assert shared_table["execution_order"][4] == 5
+
+    assert shared_table["id"][5] == 6
+    assert shared_table["hook_name"][5] == "after_commit"
+    assert shared_table["execution_order"][5] == 6
+
+    for i in range(len(shared_table) - 1):
+        assert shared_table["created_at"][i] <= shared_table["created_at"][i + 1]
diff --git a/tests/fixtures/dbt/sushi_test/macros/insert_hook.sql b/tests/fixtures/dbt/sushi_test/macros/insert_hook.sql
new file mode 100644
index 0000000000..aa27a7fe6d
--- /dev/null
+++ b/tests/fixtures/dbt/sushi_test/macros/insert_hook.sql
@@ -0,0 +1,14 @@
+{% macro insert_into_shared_hook_table(hook_name) %}
+INSERT INTO shared_hook_table (
+    id,
+    hook_name,
+    execution_order,
+    created_at
+)
+VALUES (
+    (SELECT COALESCE(MAX(id), 0) + 1 FROM shared_hook_table),
+    '{{ hook_name }}',
+    (SELECT COALESCE(MAX(id), 0) + 1 FROM shared_hook_table),
+    NOW()
+)
+{% endmacro %}
diff --git a/tests/fixtures/dbt/sushi_test/models/model_with_transaction_hooks.sql b/tests/fixtures/dbt/sushi_test/models/model_with_transaction_hooks.sql
new file mode 100644
index 0000000000..49883f73df
--- /dev/null
+++ b/tests/fixtures/dbt/sushi_test/models/model_with_transaction_hooks.sql
@@ -0,0 +1,56 @@
+{{
+    config(
+        materialized = 'table',
+
+        pre_hook = [
+            {
+                "sql": "
+                    CREATE TABLE IF NOT EXISTS hook_outside_pre_table AS
+                    SELECT
+                        1 AS id,
+                        'outside' AS location,
+                        1 AS execution_order,
+                        NOW() AS created_at
+                ",
+                "transaction": false
+            },
+
+            before_begin("
+                CREATE TABLE IF NOT EXISTS shared_hook_table (
+                    id INT,
+                    hook_name VARCHAR,
+                    execution_order INT,
+                    created_at TIMESTAMPTZ
+                )
+            "),
+
+            {
+                "sql": "{{ insert_into_shared_hook_table('inside_pre') }}",
+                "transaction": true
+            }
+        ],
+
+        post_hook = [
+            {
+                "sql": "{{ insert_into_shared_hook_table('inside_post') }}",
+                "transaction": true
+            },
+
+            {
+                "sql": "
+                    CREATE TABLE IF NOT EXISTS hook_outside_post_table AS
+                    SELECT
+                        5 AS id,
+                        'outside' AS location,
+                        5 AS execution_order,
+                        NOW() AS created_at
+                ",
+                "transaction": false
+            },
+
+            after_commit("{{ insert_into_shared_hook_table('after_commit') }}")
+        ]
+    )
+}}
+
+SELECT 1 AS id, 'test' AS name;
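
The net hook ordering introduced by the evaluator changes above, condensed into one place; this is a sketch of `_evaluate_snapshot` with session setup, seed/table creation, and audits elided, not additional code in the diff. The same pattern is applied in `create_snapshot` and `_migrate_snapshot`.

# Sketch only: names (`snapshot`, `adapter`, `render_statements_kwargs`) as in the diff.
evaluation_strategy = _evaluation_strategy(snapshot, adapter)

# 1. Hooks declared with transaction=False (e.g. dbt's before_begin) run
#    before the transaction is opened.
evaluation_strategy.run_pre_statements(
    snapshot=snapshot,
    render_kwargs={**render_statements_kwargs, "inside_transaction": False},
)

with adapter.transaction():
    # 2. Hooks with transaction=True (the default) run inside the transaction.
    evaluation_strategy.run_pre_statements(
        snapshot=snapshot,
        render_kwargs={**render_statements_kwargs, "inside_transaction": True},
    )

    # ... model evaluation ...

    # 3. In-transaction post-hooks run just before the commit.
    evaluation_strategy.run_post_statements(
        snapshot=snapshot,
        render_kwargs={**render_statements_kwargs, "inside_transaction": True},
    )

# 4. transaction=False post-hooks (e.g. dbt's after_commit) run after the commit.
evaluation_strategy.run_post_statements(
    snapshot=snapshot,
    render_kwargs={**render_statements_kwargs, "inside_transaction": False},
)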