Skip to content
This repository has been archived by the owner on Apr 8, 2024. It is now read-only.

Commit

Permalink
Avoid jinja processing of executed SQL (#157)
Browse files Browse the repository at this point in the history
* Avoid jinja processing by not calling rendering function

* Avoid jinja processing by not generating a node at all

* Add test with jinjaful write_to_source
  • Loading branch information
chamini2 authored Feb 24, 2022
1 parent 8f17b4f commit 3e5c1de
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 94 deletions.
64 changes: 0 additions & 64 deletions src/faldbt/cp/parser/sql.py

This file was deleted.

35 changes: 5 additions & 30 deletions src/faldbt/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from dbt.config.runtime import RuntimeConfig
from dbt.contracts.connection import AdapterResponse
from dbt.contracts.graph.manifest import Manifest
from dbt.parser.manifest import process_node
from dbt.logger import GLOBAL_LOGGER as logger

from . import parse
Expand All @@ -23,18 +22,14 @@
import sqlalchemy
from sqlalchemy.sql.ddl import CreateTable
from sqlalchemy.sql import Insert
from sqlalchemy.sql.schema import MetaData


DBT_V1 = dbt.semver.VersionSpecifier.from_version_string("1.0.0")
DBT_VCURRENT = dbt.version.get_installed_version()

if DBT_VCURRENT.compare(DBT_V1) >= 0:
from dbt.parser.sql import SqlBlockParser
from dbt.contracts.graph.parsed import ParsedModelNode, ParsedSourceDefinition
from dbt.contracts.sql import ResultTable, RemoteRunResult
else:
from faldbt.cp.parser.sql import SqlBlockParser
from faldbt.cp.contracts.graph.parsed import ParsedModelNode, ParsedSourceDefinition
from faldbt.cp.contracts.sql import ResultTable, RemoteRunResult

Expand Down Expand Up @@ -75,23 +70,6 @@ def register_adapters(config: RuntimeConfig):
adapters_factory.register_adapter(config)


def _get_operation_node(manifest: Manifest, project_path, profiles_dir, sql):
    """Parse raw *sql* into a dbt operation node attached to *manifest*.

    Builds a ``SqlBlockParser`` from the project's runtime config and uses it
    to parse the SQL remotely; the resulting node is then processed with
    ``process_node`` (which, presumably, runs dbt's jinja rendering/linking
    step — TODO confirm against the dbt version in use).

    :param manifest: dbt ``Manifest`` the new node is registered on.
    :param project_path: path to the dbt project directory.
    :param profiles_dir: path to the dbt profiles directory.
    :param sql: raw SQL text to wrap in a node.
    :return: the parsed sql node produced by ``parse_remote``.
    """

    config = parse.get_dbt_config(project_path, profiles_dir)
    # The same config serves as both project and root project here.
    block_parser = SqlBlockParser(
        project=config,
        manifest=manifest,
        root_project=config,
    )

    # NOTE: nodes get registered to the manifest automatically,
    # HACK: we need to include uniqueness (UUID4) to avoid clashes
    name = "SQL:" + str(hash(sql)) + ":" + str(uuid4())
    sql_node = block_parser.parse_remote(sql, name)
    process_node(config, manifest, sql_node)
    return sql_node


# NOTE: Once we get an adapter, we must call `connection_for` or `connection_named` to use it
def _get_adapter(project_path: str, profiles_dir: str):
config = parse.get_dbt_config(project_path, profiles_dir)
Expand All @@ -103,13 +81,14 @@ def _get_adapter(project_path: str, profiles_dir: str):
def _execute_sql(
manifest: Manifest, project_path: str, profiles_dir: str, sql: str
) -> Tuple[AdapterResponse, RemoteRunResult]:
node = _get_operation_node(manifest, project_path, profiles_dir, sql)
adapter = _get_adapter(project_path, profiles_dir)

logger.debug("Running query\n{}", sql)

# HACK: we need to include uniqueness (UUID4) to avoid clashes
name = "SQL:" + str(hash(sql)) + ":" + str(uuid4())
result = None
with adapter.connection_for(node):
with adapter.connection_named(name, node=None):
adapter.connections.begin()
response, execute_result = adapter.execute(sql, fetch=True)

Expand All @@ -121,7 +100,7 @@ def _execute_sql(
result = RemoteRunResult(
raw_sql=sql,
compiled_sql=sql,
node=node,
node=None,
table=table,
timing=[],
logs=[],
Expand Down Expand Up @@ -177,7 +156,7 @@ def write_target(
project_path: str,
profiles_dir: str,
target: Union[ParsedModelNode, ParsedSourceDefinition],
dtype=None
dtype=None,
) -> RemoteRunResult:
adapter = _get_adapter(project_path, profiles_dir)

Expand All @@ -190,10 +169,6 @@ def write_target(

column_names: List[str] = list(data.columns)

# Escape { and } in row data
data = data.replace('{', r'\{', regex=True)
data = data.replace('}', r'\}', regex=True)

rows = data.to_records(index=False)
row_dicts = list(map(lambda row: dict(zip(column_names, row)), rows))

Expand Down
1 change: 1 addition & 0 deletions tests/mock/models/schema.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ sources:
schema: public
tables:
- name: single_col
- name: sql_col

models:
- name: model_with_scripts
Expand Down
1 change: 1 addition & 0 deletions tests/test_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
profiles_dir = os.path.join(Path.cwd(), "tests/mock/mockProfile")
project_dir = os.path.join(Path.cwd(), "tests/mock")


def test_initialize():
faldbt = FalDbt(
profiles_dir=profiles_dir,
Expand Down
23 changes: 23 additions & 0 deletions tests/test_regressions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from fal import FalDbt
from pathlib import Path
import os
import pandas as pd

# Paths into the mock dbt project used by the regression tests.
profiles_dir = os.path.join(Path.cwd(), "tests/mock/mockProfile")
project_dir = os.path.join(Path.cwd(), "tests/mock")


# https://github.com/fal-ai/fal/issues/154
def test_write_to_source_not_processing_jinja():
    """Row data written via write_to_source must round-trip verbatim,
    even when it contains text that looks like (invalid) Jinja."""
    fal_dbt = FalDbt(
        profiles_dir=profiles_dir,
        project_dir=project_dir,
    )

    jinja_like_sql = r"SELECT 1 FROM {{ wrong jinja }}"
    frame = pd.DataFrame({"sql": [jinja_like_sql]})
    fal_dbt.write_to_source(frame, "test_sources", "sql_col")

    # TODO: look at df data
    round_tripped = fal_dbt.source("test_sources", "sql_col")
    assert round_tripped.sql.get(0) == jinja_like_sql

0 comments on commit 3e5c1de

Please sign in to comment.