Skip to content

Commit

Permalink
add types
Browse files Browse the repository at this point in the history
  • Loading branch information
buremba committed Nov 10, 2023
1 parent 8fdb499 commit 5513248
Show file tree
Hide file tree
Showing 45 changed files with 4,261 additions and 2,377 deletions.
2 changes: 0 additions & 2 deletions .streamlit/config.toml

This file was deleted.

3,896 changes: 2,401 additions & 1,495 deletions poetry.lock

Large diffs are not rendered by default.

6 changes: 1 addition & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ rich = ">=10"
pydantic = "^1.9.1"
bottle = "^0.12.23"
orjson = "^3.8.0"
fastapi = "^0.85.0"
fastapi = "^0.104.1"
openapi-schema-pydantic = "^1.2.4"
deepmerge = "^1.1.0"
jsonschema = "^3.0"
Expand All @@ -53,8 +53,6 @@ duckcli = { version = "^0.2.1", optional = true }
dbt-duckdb = { version = "^1.5.0", optional = true }
pandas = "^1.5.3"
# Deploy
fal-serverless = "0.6.29"
modal-client = "0.49.2059"
sqlglot = "12.3.0"

[tool.poetry.dev-dependencies]
Expand All @@ -81,8 +79,6 @@ playground = [
"feedparser",
]
deploy = [
"fal-serverless",
"modal-client"
]


Expand Down
21 changes: 10 additions & 11 deletions src/jinjat/core/dbt/dbt_project.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys

import dbt.adapters.factory
from dbt.exceptions import CompilationError
from dbt.exceptions import CompilationError, DbtRuntimeError
from pydantic import BaseModel

from jinjat.core.exceptions import ExecuteSqlFailure
Expand Down Expand Up @@ -134,12 +134,8 @@ def parse_project(self, init: bool = False) -> None:
self.config, self.config.load_dependencies(), self.adapter.connections.set_query_header
)
# endpatched (https://github.com/dbt-labs/dbt-core/blob/main/core/dbt/parser/manifest.py#L545)
try:
register_adapter(self.config)
self.dbt = project_parser.load()
except CompilationError as e:
logger().error(f"Encountered an error loading dbt module:\n{e}")
sys.exit(1)
register_adapter(self.config)
self.dbt = project_parser.load()
self.dbt.build_flat_graph()
project_parser.save_macros_to_adapter(self.adapter)
self._sql_parser = None
Expand Down Expand Up @@ -254,12 +250,14 @@ def safe_parse_project(self, reinit: bool = False) -> None:
raise parse_error
self.write_manifest_artifact()

def write_manifest_artifact(self) -> None:
"""Write a manifest.json to disk"""
artifact_path = os.path.join(
def get_manifest_file_path(self):
return os.path.join(
self.config.project_root, self.config.target_path, MANIFEST_FILE_NAME
)
self.dbt.write(artifact_path)

def write_manifest_artifact(self) -> None:
"""Write a manifest.json to disk"""
self.dbt.write(self.get_manifest_file_path())

def clear_caches(self) -> None:
"""Clear least recently used caches and reinstantiable container objects"""
Expand Down Expand Up @@ -337,6 +335,7 @@ def execute_sql(self, raw_sql: str, ctx: DbtQueryRequestContext, fetch: bool = T
compiled_sql = compiled_node.compiled_sql
except Exception as e:
raise ExecuteSqlFailure(raw_sql, None, e)
logger().debug(f"Executing:\n ${raw_sql}")
try:
table = self.adapter_execute(compiled_sql, fetch=fetch)
except Exception as e:
Expand Down
13 changes: 8 additions & 5 deletions src/jinjat/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from pydantic import BaseModel, validator
from starlette.requests import Request

from jinjat.core.util.api import JinjatErrorContainer, JinjatError

LIMIT_QUERY_PARAM = '_limit'
JSON_COLUMNS_QUERY_PARAM = '_json_columns'

Expand Down Expand Up @@ -87,7 +89,7 @@ async def generate_dbt_context_from_request(request: Request, openapi: dict = No
else:
body = None
return DbtQueryRequestContext(method=request.method, body=body,
headers=request.headers,
headers=dict(request.headers.items()),
params=request.path_params, query=request.query_params)


Expand Down Expand Up @@ -129,11 +131,12 @@ class JinjatColumn(BaseModel):
class JinjatExecutionResult(BaseModel):
"""Interface for execution results, this keeps us 1 layer removed from dbt interfaces which may change"""
request: DbtQueryRequestContext
adapter_response: JinjatAdapterResponse
columns: List[JinjatColumn]
data: Any
adapter_response: Optional[JinjatAdapterResponse]
columns: Optional[List[JinjatColumn]]
data: Optional[Any]
raw_sql: str
compiled_sql: str
compiled_sql: Optional[str]
error: Optional[str]

@staticmethod
def from_dbt(ctx: DbtQueryRequestContext, result: DbtAdapterExecutionResult,
Expand Down
50 changes: 25 additions & 25 deletions src/jinjat/core/routes/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ class DbtAdhocQueryRequest(BaseModel):


@app.get(
"/manifest.json"
"/manifest.json", response_model=None
)
async def execute_manifest_query(
request: Request,
response: Response,
jmespath: Optional[str] = None,
# x_dbt_project: Optional[str] = Header(default=None),
) -> any:
) -> dict:
"""Execute dbt SQL against a registered project as determined by X-dbt-Project header"""
dbt_container: DbtProjectContainer = request.app.state.dbt_project_container
project = dbt_container.get_project(None)
Expand Down Expand Up @@ -81,7 +81,17 @@ async def execute_sql(
if project is None:
raise jinjat_project_not_found_error()

dbt_result = await _execute_jinjat_query(project, project.execute_sql, body.sql, body.request, body.limit)
try:
dbt_result = await _execute_jinjat_query(project, project.execute_sql, body.sql, body.request, body.limit)
except ExecuteSqlFailure as execution_err:
raise JinjatErrorContainer(
status_code=status.HTTP_400_BAD_REQUEST,
errors=[JinjatError(
code=JinjatErrorCode.ExecuteSqlFailure,
message=str(execution_err.dbt_exception),
error=execution_err.to_model()
)]
)
json_cols = json.loads(request.query_params.get(JSON_COLUMNS_QUERY_PARAM) or '[]')

return JinjatExecutionResult.from_dbt(body.request, dbt_result, json_columns=json_cols)
Expand All @@ -97,28 +107,18 @@ async def _execute_jinjat_query(project: DbtProject, execute_function, query: st

loop = asyncio.get_running_loop()

try:
result = await loop.run_in_executor(
None, project.fn_threaded_conn(execute_function, final_query, ctx, fetch)
)
if include_total:
if limit is not None and len(result.table.rows) < limit:
size = len(result.table.rows)
else:
total_result_query = project.execute_macro('get_row_count_query', {"sql": query})
total_result_response = await loop.run_in_executor(
None, project.fn_threaded_conn(execute_function, total_result_query, ctx, True))
size = int(total_result_response.table.rows[0]['count'])
result.total_rows = size
except ExecuteSqlFailure as execution_err:
raise JinjatErrorContainer(
status_code=status.HTTP_400_BAD_REQUEST,
errors=[JinjatError(
code=JinjatErrorCode.ExecuteSqlFailure,
message=str(execution_err.dbt_exception),
error=execution_err.to_model()
)]
)
result = await loop.run_in_executor(
None, project.fn_threaded_conn(execute_function, final_query, ctx, fetch)
)
if include_total:
if limit is not None and len(result.table.rows) < limit:
size = len(result.table.rows)
else:
total_result_query = project.execute_macro('get_row_count_query', {"sql": query})
total_result_response = await loop.run_in_executor(
None, project.fn_threaded_conn(execute_function, total_result_query, ctx, True))
size = int(total_result_response.table.rows[0][0])
result.total_rows = size

return result

Expand Down
Loading

0 comments on commit 5513248

Please sign in to comment.