feat(wren-ai-service): Add Relationship Type Validation and Language Support for semantics description (#931)
paopa authored Nov 20, 2024
1 parent bc1a5a6 commit b66b2c0
Showing 6 changed files with 59 additions and 20 deletions.
@@ -1,6 +1,7 @@
import json
import logging
import sys
from enum import Enum
from pathlib import Path
from typing import Any

@@ -33,8 +34,9 @@ def column_filter(columns: list[dict]) -> list[dict]:
def prompt(
cleaned_models: dict,
prompt_builder: PromptBuilder,
language: str,
) -> dict:
return prompt_builder.run(models=cleaned_models)
return prompt_builder.run(models=cleaned_models, language=language)


@observe(as_type="generation", capture_input=False)
@@ -61,18 +63,38 @@ def wrapper(text: str) -> str:
return normalized


@observe(capture_input=False)
def validated(normalized: dict, engine: Engine) -> dict:
relationships = normalized.get("relationships", [])

validated_relationships = [
relationship
for relationship in relationships
if RelationType.is_include(relationship.get("type"))
]

# todo: once wren-engine provides a function to validate the relationships, we will use it here
# for now, we just return the normalized relationships
return normalized

return {"relationships": validated_relationships}


## End of Pipeline
class RelationType(Enum):
MANY_TO_ONE = "MANY_TO_ONE"
ONE_TO_MANY = "ONE_TO_MANY"
ONE_TO_ONE = "ONE_TO_ONE"

@classmethod
def is_include(cls, value: str) -> bool:
return value in cls._value2member_map_


class ModelRelationship(BaseModel):
name: str
fromModel: str
fromColumn: str
type: str
type: RelationType
toModel: str
toColumn: str
reason: str
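
For illustration, here is a minimal standalone sketch of the new type check; the sample relationship entries are hypothetical and only show how an unsupported MANY_TO_MANY suggestion gets filtered out:

from enum import Enum


class RelationType(Enum):
    MANY_TO_ONE = "MANY_TO_ONE"
    ONE_TO_MANY = "ONE_TO_MANY"
    ONE_TO_ONE = "ONE_TO_ONE"

    @classmethod
    def is_include(cls, value: str) -> bool:
        # True only for values that map to a member of the enum
        return value in cls._value2member_map_


# Hypothetical normalized LLM output; MANY_TO_MANY is not a valid type.
normalized = {
    "relationships": [
        {"name": "OrdersToCustomers", "type": "MANY_TO_ONE"},
        {"name": "ProductsToTags", "type": "MANY_TO_MANY"},
    ]
}

validated_relationships = [
    r for r in normalized["relationships"] if RelationType.is_include(r.get("type"))
]
print(validated_relationships)  # keeps only the MANY_TO_ONE relationship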
@@ -97,7 +119,7 @@ class RelationshipResult(BaseModel):
- **name**: A descriptive name for the relationship.
- **fromModel**: The name of the source model.
- **fromColumn**: The column in the source model that forms the relationship.
- **type**: The type of relationship, which can be MANY_TO_ONE, ONE_TO_MANY or ONE_TO_ONE.
- **type**: The type of relationship, which can be "MANY_TO_ONE", "ONE_TO_MANY" or "ONE_TO_ONE" only.
- **toModel**: The name of the target model.
- **toColumn**: The column in the target model that forms the relationship.
- **reason**: The reason for recommending this relationship.
@@ -106,6 +128,7 @@ class RelationshipResult(BaseModel):
1. Do not recommend relationships within the same model (fromModel and toModel must be different).
2. Only suggest relationships if there is a clear and beneficial reason to do so.
3. If there are no good relationships to recommend or if there are fewer than two models, return an empty list of relationships.
4. Use "MANY_TO_ONE" and "ONE_TO_MANY" instead of "MANY_TO_MANY" relationships.
Output all relationships in the following JSON structure:
@@ -132,13 +155,14 @@ class RelationshipResult(BaseModel):
"""

user_prompt_template = """
Here is my data model's relationship specification:
Here is the relationship specification for my data model:
{{models}}
**Please analyze these models and suggest optimizations for their relationships.**
Take into account best practices in database design, opportunities for normalization, indexing strategies, and any additional relationships that could improve data integrity and enhance query performance.
**Please review these models and provide recommendations of relationship to optimize them.**
Consider best practices in database design, potential normalization opportunities, indexing strategies, and any additional relationships that might enhance data integrity and query performance.
Use this for the relationship name and reason: {{language}}
"""


@@ -167,6 +191,7 @@ def __init__(
def visualize(
self,
mdl: dict,
language: str = "English",
) -> None:
destination = "outputs/pipelines/generation"
if not Path(destination).exists():
@@ -177,6 +202,7 @@ def visualize(
output_file_path=f"{destination}/relationship_recommendation.dot",
inputs={
"mdl": mdl,
"language": language,
**self._components,
},
show_legend=True,
@@ -187,12 +213,14 @@ async def run(
async def run(
self,
mdl: dict,
language: str = "English",
) -> dict:
logger.info("Relationship Recommendation pipeline is running...")
return await self._pipe.execute(
[self._final],
inputs={
"mdl": mdl,
"language": language,
**self._components,
},
)
@@ -201,23 +229,22 @@ async def run(
if __name__ == "__main__":
from langfuse.decorators import langfuse_context

from src.core.engine import EngineConfig
from src.config import settings
from src.core.pipeline import async_validate
from src.providers import init_providers
from src.utils import init_langfuse, load_env_vars
from src.providers import generate_components
from src.utils import init_langfuse

load_env_vars()
pipe_components = generate_components(settings.components)
pipeline = RelationshipRecommendation(
**pipe_components["relationship_recommendation"]
)
init_langfuse()

llm_provider, _, _, engine = init_providers(EngineConfig())
pipeline = RelationshipRecommendation(llm_provider=llm_provider, engine=engine)

with open("sample/college_3_bigquery_mdl.json", "r") as file:
with open("sample/woocommerce_bigquery_mdl.json", "r") as file:
mdl = json.load(file)

input = {"mdl": mdl}
input = {"mdl": mdl, "language": "Traditional Chinese"}

pipeline.visualize(**input)
async_validate(lambda: pipeline.run(**input))

langfuse_context.flush()
@@ -11,6 +11,7 @@
get_service_container,
get_service_metadata,
)
from src.web.v1.services import Configuration
from src.web.v1.services.relationship_recommendation import RelationshipRecommendation

router = APIRouter()
@@ -26,7 +27,10 @@
- Request body: PostRequest
{
"mdl": "{ ... }", # JSON string of the MDL (Model Definition Language)
"project_id": "project-id" # Optional project ID
"project_id": "project-id", # Optional project ID
"configuration": { # Optional configuration settings
"language": "English", # Language for the recommendation
}
}
- Response: PostResponse
{
@@ -64,6 +68,7 @@
class PostRequest(BaseModel):
mdl: str
project_id: Optional[str] = None
configuration: Optional[Configuration] = Configuration()


class PostResponse(BaseModel):
@@ -87,6 +92,8 @@ async def recommend(
input = RelationshipRecommendation.Input(
id=id,
mdl=request.mdl,
project_id=request.project_id,
configuration=request.configuration,
)

background_tasks.add_task(
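
A hedged example of how a client could exercise the new optional configuration field. The host, port, and route path below are assumptions, since the router prefix is not visible in this diff:

import json

import httpx

payload = {
    "mdl": json.dumps({"models": []}),  # JSON string of the MDL
    "project_id": "my-project",  # optional
    "configuration": {"language": "Traditional Chinese"},  # optional; omit to default to English
}

# Endpoint URL is a guess for illustration only.
resp = httpx.post(
    "http://localhost:5556/v1/relationship-recommendations", json=payload
)
print(resp.json())  # e.g. an id to poll for the recommendation result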
@@ -118,6 +118,7 @@ async def generate(
user_prompt=request.user_prompt,
mdl=request.mdl,
configuration=request.configuration,
project_id=request.project_id,
)

background_tasks.add_task(
@@ -8,7 +8,7 @@

from src.core.pipeline import BasicPipeline
from src.utils import trace_metadata
from src.web.v1.services import MetadataTraceable
from src.web.v1.services import Configuration, MetadataTraceable

logger = logging.getLogger("wren-ai-service")

@@ -17,6 +17,8 @@ class RelationshipRecommendation:
class Input(BaseModel):
id: str
mdl: str
project_id: Optional[str] = None # this is for tracing purpose
configuration: Optional[Configuration] = Configuration()

class Resource(BaseModel, MetadataTraceable):
class Error(BaseModel):
@@ -62,6 +64,7 @@ async def recommend(self, request: Input, **kwargs) -> Resource:

input = {
"mdl": mdl_dict,
"language": request.configuration.language,
}

resp = await self._pipelines["relationship_recommendation"].run(**input)
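
A small sketch of how these fields fit together on the service side, assuming Configuration is a pydantic model whose language field defaults to "English" (consistent with the .language access above); the sample values are made up:

import json

from src.web.v1.services import Configuration
from src.web.v1.services.relationship_recommendation import RelationshipRecommendation

request = RelationshipRecommendation.Input(
    id="some-request-id",
    mdl=json.dumps({"models": []}),  # placeholder MDL payload
    project_id="my-project",  # optional, used only for tracing
    configuration=Configuration(language="Traditional Chinese"),
)

# Inside recommend(), the pipeline input then becomes roughly:
#   {"mdl": {"models": []}, "language": "Traditional Chinese"}
# and leaving configuration unset falls back to Configuration(), i.e. English.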
@@ -21,6 +21,7 @@ class Input(BaseModel):
user_prompt: str
mdl: str
configuration: Optional[Configuration] = Configuration()
project_id: Optional[str] = None # this is for tracing purpose

class Resource(BaseModel, MetadataTraceable):
class Error(BaseModel):
@@ -27,7 +27,7 @@ async def test_recommend_success(relationship_recommendation_service, mock_pipeline):
assert response.id == "test_id"
assert response.status == "finished"
assert response.response == {"test": "data"}
mock_pipeline.run.assert_called_once_with(mdl={"key": "value"})
mock_pipeline.run.assert_called_once_with(mdl={"key": "value"}, language="English")


@pytest.mark.asyncio
