Skip to content

Commit

Permalink
add copilot
Browse files Browse the repository at this point in the history
  • Loading branch information
buremba committed Nov 24, 2023
1 parent 5513248 commit 77b4efc
Show file tree
Hide file tree
Showing 53 changed files with 5,297 additions and 1,813 deletions.
20 changes: 14 additions & 6 deletions src/jinjat/core/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import typing
from collections import OrderedDict
from copy import deepcopy
from typing import (
Any,
List,
Expand All @@ -9,6 +10,7 @@
)

import agate
import jsonref
from dbt.contracts.connection import AdapterResponse
from dbt.contracts.graph.manifest import ManifestNode
from fastapi.openapi.models import Parameter, Schema
Expand All @@ -18,10 +20,7 @@
from pydantic import BaseModel, validator
from starlette.requests import Request

from jinjat.core.util.api import JinjatErrorContainer, JinjatError

LIMIT_QUERY_PARAM = '_limit'
JSON_COLUMNS_QUERY_PARAM = '_json_columns'


class CORS(BaseModel):
Expand All @@ -37,12 +36,15 @@ class JinjatProjectConfig(BaseModel):

@validator('openapi')
def validate_openapi(cls, openapi):
info = openapi.get("info", {})
if "info" not in openapi:
openapi['info'] = {}
info = openapi.get("info")
if "title" in info is not None or "version" in info is not None:
raise ValueError(
"`openapi.info.title` and `openapi.info.version` must not be set as the values are derived from dbt_project.yml")

openapi['info']['title'] = openapi['info']['version'] = ""
info['title'] = info.get('title', '')
info['version'] = info.get('version', '')
OpenAPI.parse_obj(openapi)
return openapi

Expand Down Expand Up @@ -141,12 +143,18 @@ class JinjatExecutionResult(BaseModel):
@staticmethod
def from_dbt(ctx: DbtQueryRequestContext, result: DbtAdapterExecutionResult,
transform_response: typing.Callable[[dict], dict] = None,
json_columns: List[str] = []) -> 'JinjatExecutionResult':
response_schema: dict = None) -> 'JinjatExecutionResult':
columns = [JinjatColumn(name=column.name, type=column.data_type.__class__.__name__) for column in
result.table.columns]
adapter_response = JinjatAdapterResponse(message=result.adapter_response._message,
code=result.adapter_response.code,
rows_affected=result.adapter_response.rows_affected)
if response_schema.get('type') == 'object':
json_columns = [key for (key, value) in response_schema.get('properties', {}).items()
if value.get('type') in ['array', 'object']]
else:
json_columns = []

result_dict = _convert_table_to_dict(result.table, json_columns)
if transform_response is not None:
result_dict = transform_response(result_dict)
Expand Down
6 changes: 3 additions & 3 deletions src/jinjat/core/routes/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from jinjat.core.dbt.dbt_project import DbtProjectContainer, DbtProject
from jinjat.core.exceptions import ExecuteSqlFailure
from jinjat.core.models import JinjatExecutionResult, DbtAdapterExecutionResult, generate_dbt_context_from_request, \
DbtQueryRequestContext, JSON_COLUMNS_QUERY_PARAM
DbtQueryRequestContext
from jinjat.core.util.api import JinjatErrorContainer, JinjatError, JinjatErrorCode, DBT_PROJECT_HEADER, \
DBT_PROJECT_NAME, JSONAPIException
from jinjat.core.util.jmespath import extract_jmespath
Expand Down Expand Up @@ -92,9 +92,9 @@ async def execute_sql(
error=execution_err.to_model()
)]
)
json_cols = json.loads(request.query_params.get(JSON_COLUMNS_QUERY_PARAM) or '[]')

return JinjatExecutionResult.from_dbt(body.request, dbt_result, json_columns=json_cols)
# TODO: fix openapi_dict
return JinjatExecutionResult.from_dbt(body.request, dbt_result, openapi_dict={})


async def _execute_jinjat_query(project: DbtProject, execute_function, query: str, ctx: DbtQueryRequestContext,
Expand Down
62 changes: 42 additions & 20 deletions src/jinjat/core/routes/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import itertools
import json
import re
import sys
from copy import deepcopy
from pathlib import Path
from typing import List, Optional, Callable
Expand All @@ -23,8 +24,9 @@

from jinjat.core.dbt.dbt_project import DbtProject
from jinjat.core.exceptions import InvalidJinjaConfig, ExecuteSqlFailure
from jinjat.core.log_controller import logger
from jinjat.core.models import generate_dbt_context_from_request, JinjatExecutionResult, JinjatAnalysisConfig, \
JinjatProjectConfig, JSON_COLUMNS_QUERY_PARAM
JinjatProjectConfig
from jinjat.core.routes.admin import _execute_jinjat_query
from jinjat.core.schema.data_type import get_json_schema_from_data_type
from jinjat.core.util.api import get_human_readable_error, rapidoc_html, register_jsonapi_exception_handlers, \
Expand All @@ -37,12 +39,10 @@

async def handle_analysis_api(project: DbtProject,
sql: str,
operation: Operation,
openapi_dict: dict,
transform_request: Callable[[dict], dict],
transform_response: Callable[[dict], dict],
fetch: bool,
jinjat_project: JinjatProjectConfig,
request: Request,
response: Response):
context = await generate_dbt_context_from_request(request, openapi_dict, transform_request)
Expand All @@ -54,8 +54,9 @@ async def handle_analysis_api(project: DbtProject,
try:
query_result = await _execute_jinjat_query(project, project.execute_sql, sql, context,
limit, fetch, include_total=end is not None)
json_cols = json.loads(request.query_params.get(JSON_COLUMNS_QUERY_PARAM) or '[]')
jinjat_result = JinjatExecutionResult.from_dbt(context, query_result, transform_response, json_cols)
response_schema = openapi_dict.get("responses", {}).get(200, {}).get("content", {}).get("application/json",
{}).get("schema", {})
jinjat_result = JinjatExecutionResult.from_dbt(context, query_result, transform_response, response_schema)
if query_result.total_rows is not None:
response.headers['x-total-count'] = str(query_result.total_rows)
except ExecuteSqlFailure as execution_err:
Expand All @@ -76,11 +77,13 @@ async def handle_analysis_api(project: DbtProject,
return jinjat_result.data


def create_components_from_nodes(project: DbtProject, request: Request):
def create_components_from_nodes(project: DbtProject):
schema_nodes = filter(
lambda node: node.resource_type in ['model', 'seed', 'source', 'analysis'] and (
'jinjat' in node.meta or 'jinjat' in node.config),
project.dbt.nodes.values())
lambda node: node.resource_type in ['model', 'seed', 'source', 'analysis']
# and ('jinjat' in node.meta or 'jinjat' in node.config)
,
# https://stackoverflow.com/questions/11941817/how-can-i-avoid-runtimeerror-dictionary-changed-size-during-iteration-error
project.dbt.nodes.copy().values())

components = {}
for node in schema_nodes:
Expand Down Expand Up @@ -139,7 +142,14 @@ def get_final_response(transform: Optional[str], request_body_model: Optional[Sc
return Schema.parse_obj({"type": "array", "items": request_body_model})


async def custom_openapi(project, jinjat_project_config, api, package_name, req: Request) -> JSONResponse:
def generate_schema(project: DbtProject, openapi_schema):
component_schemas = create_components_from_nodes(project)
components = openapi_schema.setdefault('components', {})
existing_schemas = components.setdefault('schemas', {})
return {**existing_schemas, **component_schemas}


async def custom_openapi(project, jinjat_project_config, api, req: Request) -> JSONResponse:
extract_path = req.query_params.get("jmespath")
scheme = req.headers.get('x-forwarded-proto')
url = str(req.base_url.replace(scheme=scheme or req.url.scheme))
Expand All @@ -151,14 +161,11 @@ async def custom_openapi(project, jinjat_project_config, api, package_name, req:

openapi_schema = get_openapi(title=project.project_name,
version=project.config.version,
routes=api.routes,
servers=servers)
routes=api.routes)

component_schemas = create_components_from_nodes(project, req)
components = openapi_schema.setdefault('components', {})
existing_schemas = components.setdefault('schemas', {})
components['schemas'] = {**existing_schemas, **component_schemas}
openapi_schema["components"] = {"schemas": generate_schema(project, openapi_schema)}
openapi_schema["x-jinjat"] = {"refine": jinjat_project_config.refine}
openapi_schema['servers'] = servers

if jinjat_project_config.openapi is not None:
always_merger.merge(openapi_schema, jinjat_project_config.openapi)
Expand All @@ -185,7 +192,10 @@ def enrich_openapi_schema(project: DbtProject, openapi: Operation, config: Jinja
ref=f"#/components/schemas/{node.unique_id}")
))})}

openapi.parameters = (openapi.parameters or []) + (config.request.parameters or [])
params = (openapi.parameters or [])
if config.request is not None and config.request.parameters is not None:
params = params + config.request.parameters
openapi.parameters = params


def register_openapi_validators(project: DbtProject):
Expand Down Expand Up @@ -240,7 +250,7 @@ async def lookup_by_id(request: Request, response: Response):
# sub_app.add_route("/elements", functools.partial(elements_html, CustomButton("Admin APIs", "/")),
# include_in_schema=False)
sub_app.add_route(f"/{package_name}/openapi.json",
functools.partial(custom_openapi, project, jinjat_project_config, sub_app, package_name),
functools.partial(custom_openapi, project, jinjat_project_config, sub_app),
include_in_schema=True)

for node in analyses:
Expand Down Expand Up @@ -296,8 +306,20 @@ async def lookup_by_id(request: Request, response: Response):
raise InvalidJinjaConfig(node.original_file_path, None,
f"Unable to parse `transform_response` jmespath expression {jinjat_config.response.transform}: {e}")

endpoint = functools.partial(handle_analysis_api, project, node.raw_code, openapi, openapi_dict, transform_request,
transform_response, fetch_enabled, jinjat_project_config)
openapi_dict_resolved = deepcopy(openapi_dict)
schemas = generate_schema(project, {})
openapi_dict_resolved["components"] = {"schemas": schemas}
try:
openapi_dict_resolved = jsonref.replace_refs(openapi_dict_resolved, base_uri="", proxies=False,
lazy_load=False)
except jsonref.JsonRefError as e:
logger().error(
f"Error generating route {node.unique_id}\nOpenAPI schema validation failed: ${e.message}")
sys.exit(1)

endpoint = functools.partial(handle_analysis_api, project, node.raw_code, openapi_dict_resolved,
transform_request,
transform_response, fetch_enabled)
analysis_lookup[node.unique_id] = endpoint
sub_app.add_api_route(f'/{package_name}/{project.config.dependencies[package_name].version}/{api_path}',
endpoint=endpoint,
Expand Down
2 changes: 0 additions & 2 deletions src/jinjat/core/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ def homepage_without_ui(host, project: DbtProject, dbt_target: DbtTarget) -> dic


def mount_app(app: FastAPI, project: DbtProject, dbt_target: DbtTarget):
logger().info(f"start")
config = get_jinjat_project_config(project.project_root)

app.add_middleware(
Expand All @@ -114,7 +113,6 @@ def mount_app(app: FastAPI, project: DbtProject, dbt_target: DbtTarget):
allow_headers=["*"],
expose_headers=[DBT_PROJECT_HEADER, DBT_PROJECT_NAME]
)
logger().info(f"start2")
register_jsonapi_exception_handlers(app)
app.openapi = lambda: custom_openapi(project, config)

Expand Down
2 changes: 1 addition & 1 deletion src/jinjat/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def serve(
port: int,
vars: str,
refine: Optional[bool] = False,
) -> object:
):
logger().info(f":water_wave: Executing jinjat for dbt project in {project_dir}")

dbt_target = DbtTarget(project_dir=project_dir, profiles_dir=profiles_dir, target=target,
Expand Down
3 changes: 2 additions & 1 deletion src/ui/.env.local
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
NEXT_PUBLIC_JINJAT_URL=http://127.0.0.1:8581
NEXT_PUBLIC_JINJAT_URL=http://127.0.0.1:8581
OPENAI_API_KEY=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
34 changes: 0 additions & 34 deletions src/ui/app/[...catchAll]/page.tsx

This file was deleted.

2 changes: 1 addition & 1 deletion src/ui/next.config.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
const nextConfig = {
output: 'export',
// output: 'export',
distDir: 'dist',
pageExtensions: ['js', 'jsx', 'ts', 'tsx'],
webpack: (config) => {
Expand Down
Loading

0 comments on commit 77b4efc

Please sign in to comment.