Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tidy] Improve markdown strip and add tests #861

Merged
merged 6 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/build-vizro-whl.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ on:
- main
paths:
- "vizro-core/**"
- "vizro-ai/examples/**"

defaults:
run:
Expand Down
18 changes: 13 additions & 5 deletions vizro-ai/src/vizro_ai/plot/_response_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,18 @@
CUSTOM_CHART_NAME = "custom_chart"


def _strip_markdown(code_string: str) -> str:
"""Strip markdown code block from the code string."""
prefixes = ["```python\n", "```py\n", "```\n"]
for prefix in prefixes:
if code_string.startswith(prefix):
code_string = code_string[len(prefix) :]
break
if code_string.endswith("```"):
code_string = code_string[: -len("```")]
return code_string.strip()


def _format_and_lint(code_string: str) -> str:
# Tracking https://github.com/astral-sh/ruff/issues/659 for proper Python API
# Good example: https://github.com/astral-sh/ruff/issues/8401#issuecomment-1788806462
Expand Down Expand Up @@ -93,11 +105,7 @@ class ChartPlan(BaseModel):

@validator("chart_code")
def _check_chart_code(cls, v):
# Remove markdown code block if present
if v.startswith("```python\n") and v.endswith("```"):
v = v[len("```python\n") : -3].strip()
elif v.startswith("```\n") and v.endswith("```"):
v = v[len("```\n") : -3].strip()
v = _strip_markdown(v)

# TODO: add more checks: ends with return, has return, no second function def, only one indented line
if f"def {CUSTOM_CHART_NAME}(" not in v:
Expand Down
43 changes: 37 additions & 6 deletions vizro-ai/tests/unit/vizro-ai/plot/test_response_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,48 @@ def test_check_chart_code_invalid(self, chart_code, error_type, error_message):
class TestChartPlanFactory:
"""Tests for the ChartPlanFactory class, mainly the execution of the chart code."""

def test_execute_chart_code_valid(self):
chart_plan_dynamic = ChartPlanFactory(data_frame=px.data.iris())
chart_plan_dynamic_valid = chart_plan_dynamic(
chart_type="Bubble Chart",
imports=["import plotly.express as px", "import numpy as np", "import random"],
chart_code="""def custom_chart(data_frame):
@pytest.mark.parametrize(
"chart_code",
[
"""def custom_chart(data_frame):
random_other_module_import = np.arange(10)
other_random_module_import = random.sample(range(10), 10)
fig = px.scatter(data_frame, x='sepal_width', y='petal_width')
return fig
""",
"""def custom_chart(data_frame):
fig = px.scatter(data_frame, x='sepal_length', y='petal_length')
return fig
""",
"""```python
def custom_chart(data_frame):
random_other_module_import = np.arange(10)
other_random_module_import = random.sample(range(10), 10)
fig = px.scatter(data_frame, x='sepal_width', y='petal_width')
return fig
```""",
"""```py
def custom_chart(data_frame):
random_other_module_import = np.arange(10)
other_random_module_import = random.sample(range(10), 10)
fig = px.scatter(data_frame, x='sepal_width', y='petal_width')
return fig
```""",
"""```
def custom_chart(data_frame):
random_other_module_import = np.arange(10)
other_random_module_import = random.sample(range(10), 10)
fig = px.scatter(data_frame, x='sepal_width', y='petal_width')
return fig
```""",
],
)
def test_execute_chart_code_valid(self, chart_code):
chart_plan_dynamic = ChartPlanFactory(data_frame=px.data.iris())
chart_plan_dynamic_valid = chart_plan_dynamic(
chart_type="Bubble Chart",
imports=["import plotly.express as px", "import numpy as np", "import random"],
chart_code=chart_code,
chart_insights="Very good insights",
code_explanation="Very good explanation",
)
Expand Down
Loading