Skip to content

Commit

Permalink
feat: merge two Rust functions into a class
Browse files Browse the repository at this point in the history
  • Loading branch information
grieve54706 committed Nov 29, 2024
1 parent 4961706 commit 441d724
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 42 deletions.
8 changes: 2 additions & 6 deletions ibis-server/app/mdl/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,8 @@ def get_session_context(
return wren_core.SessionContext(manifest_str, function_path)


def resolve_used_table_names(manifest_str: str, sql: str) -> list[str]:
return wren_core.resolve_used_table_names(manifest_str, sql)


def extract_manifest(manifest_str: str, datasets: list[str]) -> dict:
return wren_core.extract_manifest(manifest_str, datasets)
def get_extractor(manifest_str: str) -> wren_core.Extractor:
    """Create a ``wren_core.Extractor`` from the given manifest string.

    The returned object provides ``resolve_used_table_names`` and
    ``extract_manifest`` without the manifest string being passed again.
    """
    extractor = wren_core.Extractor(manifest_str)
    return extractor


def to_json_base64(manifest):
Expand Down
13 changes: 7 additions & 6 deletions ibis-server/app/mdl/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@

from app.config import get_config
from app.mdl.core import (
extract_manifest,
get_extractor,
get_session_context,
resolve_used_table_names,
to_json_base64,
)
from app.model import InternalServerError, UnprocessableEntityError
Expand Down Expand Up @@ -64,8 +63,9 @@ def __init__(self, manifest_str: str):

def rewrite(self, sql: str) -> str:
try:
tables = resolve_used_table_names(self.manifest_str, sql)
manifest = extract_manifest(self.manifest_str, tables)
extractor = get_extractor(self.manifest_str)
tables = extractor.resolve_used_table_names(sql)
manifest = extractor.extract_manifest(tables)
manifest_str = to_json_base64(manifest)
r = httpx.request(
method="GET",
Expand All @@ -92,8 +92,9 @@ def __init__(self, manifest_str: str, function_path: str):

def rewrite(self, sql: str) -> str:
try:
tables = resolve_used_table_names(self.manifest_str, sql)
manifest = extract_manifest(self.manifest_str, tables)
extractor = get_extractor(self.manifest_str)
tables = extractor.resolve_used_table_names(sql)
manifest = extractor.extract_manifest(tables)
session_context = get_session_context(
to_json_base64(manifest), self.function_path
)
Expand Down
59 changes: 36 additions & 23 deletions wren-core-py/src/extractor.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,48 @@
use crate::errors::CoreError;
use crate::manifest::{to_manifest, PyManifest};
use pyo3::pyfunction;
use pyo3::{pyclass, pymethods};
use std::collections::HashSet;
use std::sync::Arc;
use wren_core::mdl::manifest::{Manifest, Model, Relationship, View};
use wren_core::mdl::manifest::{Model, Relationship, View};
use wren_core::mdl::WrenMDL;

/// Parse the given SQL and return the list of used table names.
#[pyfunction]
#[pyo3(name = "resolve_used_table_names", signature = (mdl_base64, sql), text_signature = "(mdl_base64: str, sql: str)")]
pub fn py_resolve_used_table_names(
mdl_base64: &str,
sql: &str,
) -> Result<Vec<String>, CoreError> {
let manifest = to_manifest(mdl_base64)?;
resolve_used_table_names(manifest, sql)
/// Extractor exposed to Python under the class name `Extractor`.
///
/// It keeps a shared reference to the `WrenMDL` so that repeated calls on the
/// same instance reuse it.
#[pyclass(name = "Extractor")]
#[derive(Clone)]
pub struct PyExtractor {
    /// Shared MDL that every method of this extractor works against.
    mdl: Arc<WrenMDL>,
}

#[pymethods]
impl PyExtractor {
    /// Build an extractor from a base64-encoded manifest string.
    ///
    /// # Errors
    /// Returns a `CoreError` when `to_manifest` cannot turn `mdl_base64` into a
    /// manifest. Propagating the error (instead of unwrapping) lets the Python
    /// caller receive an exception rather than a panic.
    #[new]
    pub fn new(mdl_base64: &str) -> Result<Self, CoreError> {
        let manifest = to_manifest(mdl_base64)?;
        Ok(Self {
            mdl: WrenMDL::new_ref(manifest),
        })
    }

    /// Parse the given SQL and return the list of used table names.
    ///
    /// # Errors
    /// Returns a `CoreError` when the underlying resolution fails.
    pub fn resolve_used_table_names(&self, sql: &str) -> Result<Vec<String>, CoreError> {
        resolve_used_table_names(&self.mdl, sql)
    }

    /// Given a list of used datasets, extract a manifest by removing unused datasets.
    /// If a model is related to another dataset, both datasets will be kept.
    /// The relationship between them will be kept as well.
    /// A dataset can be a model or a view.
    ///
    /// # Errors
    /// Returns a `CoreError` when the extraction fails.
    pub fn extract_manifest(
        &self,
        used_datasets: Vec<String>,
    ) -> Result<PyManifest, CoreError> {
        extract_manifest(&self.mdl, used_datasets)
    }
}

fn resolve_used_table_names(
manifest: Manifest,
mdl: &Arc<WrenMDL>,
sql: &str,
) -> Result<Vec<String>, CoreError> {
let mdl = WrenMDL::new_ref(manifest);
let ctx_state = wren_core::SessionContext::new().state();
ctx_state
.sql_to_statement(sql, "generic")
Expand All @@ -43,18 +64,10 @@ fn resolve_used_table_names(
})
}

/// Given a used dataset list, extract manifest by removing unused datasets.
/// If a model is related to another dataset, both datasets will be kept.
/// The relationship between them will be kept as well.
/// A dataset can be a model or a view.
#[pyfunction]
#[pyo3(signature = (mdl_base64, used_datasets), text_signature = "(mdl_base64: str, used_datasets: list[str])")]
pub fn extract_manifest(
mdl_base64: &str,
mdl: &Arc<WrenMDL>,
used_datasets: Vec<String>,
) -> Result<PyManifest, CoreError> {
let manifest = to_manifest(mdl_base64)?;
let mdl = WrenMDL::new_ref(manifest);
let used_models = extract_models(&mdl, &used_datasets);
let (used_views, models_of_views) = extract_views(&mdl, &used_datasets);
let used_relationships = extract_relationships(&mdl, &used_datasets);
Expand Down Expand Up @@ -103,7 +116,7 @@ fn extract_views(
.iter()
.filter_map(|&dataset_name| {
mdl.get_view(dataset_name).and_then(|view| {
resolve_used_table_names(mdl.manifest.clone(), view.statement.as_str())
resolve_used_table_names(mdl, view.statement.as_str())
.ok()
.map(|used_tables| extract_models(mdl, &used_tables))
})
Expand Down
3 changes: 1 addition & 2 deletions wren-core-py/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ fn wren_core_wrapper(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<context::PySessionContext>()?;
m.add_class::<PyRemoteFunction>()?;
m.add_class::<manifest::PyManifest>()?;
m.add_function(wrap_pyfunction!(extractor::py_resolve_used_table_names, m)?)?;
m.add_function(wrap_pyfunction!(extractor::extract_manifest, m)?)?;
m.add_class::<extractor::PyExtractor>()?;
m.add_function(wrap_pyfunction!(manifest::to_json_base64, m)?)?;
Ok(())
}
9 changes: 4 additions & 5 deletions wren-core-py/tests/test_modeling_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
import pytest
from wren_core import (
SessionContext,
resolve_used_table_names,
extract_manifest,
Extractor,
to_json_base64,
)

Expand Down Expand Up @@ -155,7 +154,7 @@ def test_get_available_functions():
],
)
def test_resolve_used_table_names(sql, expected):
tables = resolve_used_table_names(manifest_str, sql)
tables = Extractor(manifest_str).resolve_used_table_names(sql)
assert tables == expected


Expand All @@ -169,13 +168,13 @@ def test_resolve_used_table_names(sql, expected):
],
)
def test_extract_manifest(dataset, expected_models):
extracted_manifest = extract_manifest(manifest_str, dataset)
extracted_manifest = Extractor(manifest_str).extract_manifest(dataset)
assert len(extracted_manifest.models) == len(expected_models)
assert [m.name for m in extracted_manifest.models] == expected_models


def test_to_json_base64():
extracted_manifest = extract_manifest(manifest_str, ["customer"])
extracted_manifest = Extractor(manifest_str).extract_manifest(["customer"])
base64_str = to_json_base64(extracted_manifest)
with does_not_raise():
json_str = base64.b64decode(base64_str)
Expand Down

0 comments on commit 441d724

Please sign in to comment.