Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tidy] Improve markdown strip and add tests #861

Merged
merged 6 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/build-vizro-whl.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ on:
- main
paths:
- "vizro-core/**"
- "vizro-ai/examples/**"

defaults:
run:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
<!--
A new scriv changelog fragment.

Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨

- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Removed

- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Added

- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Changed

- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Deprecated

- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Fixed

- A bullet item for the Fixed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Security

- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
2 changes: 1 addition & 1 deletion vizro-ai/examples/dashboard_ui/requirements.in
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
gunicorn
vizro-ai>=0.3.0
vizro-ai>=0.3.2
black
openpyxl
langchain_anthropic
Expand Down
21 changes: 2 additions & 19 deletions vizro-ai/examples/dashboard_ui/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@ anyio==4.4.0
# anthropic
# httpx
# openai
async-timeout==4.0.3
# via
# aiohttp
# langchain
attrs==24.2.0
# via aiohttp
autoflake==2.3.1
Expand Down Expand Up @@ -67,8 +63,6 @@ distro==1.9.0
# openai
et-xmlfile==1.1.0
# via openpyxl
exceptiongroup==1.2.2
# via anyio
filelock==3.16.1
# via huggingface-hub
flask==3.0.3
Expand All @@ -83,8 +77,6 @@ frozenlist==1.4.1
# aiosignal
fsspec==2024.10.0
# via huggingface-hub
greenlet==3.1.0
# via sqlalchemy
gunicorn==23.0.0
# via -r requirements.in
h11==0.14.0
Expand All @@ -108,9 +100,7 @@ idna==3.8
# requests
# yarl
importlib-metadata==8.5.0
# via
# dash
# flask
# via dash
itsdangerous==2.2.0
# via flask
jinja2==3.1.4
Expand Down Expand Up @@ -256,23 +246,16 @@ tokenizers==0.20.1
# via
# anthropic
# langchain-mistralai
tomli==2.0.1
# via
# autoflake
# black
tqdm==4.66.5
# via
# huggingface-hub
# openai
typing-extensions==4.12.2
# via
# anthropic
# anyio
# black
# dash
# huggingface-hub
# langchain-core
# multidict
# openai
# pydantic
# pydantic-core
Expand All @@ -283,7 +266,7 @@ urllib3==2.2.3
# via requests
vizro==0.1.23
# via vizro-ai
vizro-ai==0.3.0
vizro-ai==0.3.2
# via -r requirements.in
werkzeug==3.0.4
# via
Expand Down
18 changes: 13 additions & 5 deletions vizro-ai/src/vizro_ai/plot/_response_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,18 @@
CUSTOM_CHART_NAME = "custom_chart"


def _strip_markdown(code_string: str) -> str:
"""Strip markdown code block from the code string."""
prefixes = ["```python\n", "```py\n", "```\n"]
for prefix in prefixes:
if code_string.startswith(prefix):
code_string = code_string[len(prefix) :]
break
if code_string.endswith("```"):
code_string = code_string[: -len("```")]
return code_string.strip()


def _format_and_lint(code_string: str) -> str:
# Tracking https://github.com/astral-sh/ruff/issues/659 for proper Python API
# Good example: https://github.com/astral-sh/ruff/issues/8401#issuecomment-1788806462
Expand Down Expand Up @@ -93,11 +105,7 @@ class ChartPlan(BaseModel):

@validator("chart_code")
def _check_chart_code(cls, v):
# Remove markdown code block if present
if v.startswith("```python\n") and v.endswith("```"):
v = v[len("```python\n") : -3].strip()
elif v.startswith("```\n") and v.endswith("```"):
v = v[len("```\n") : -3].strip()
v = _strip_markdown(v)

# TODO: add more checks: ends with return, has return, no second function def, only one indented line
if f"def {CUSTOM_CHART_NAME}(" not in v:
Expand Down
43 changes: 37 additions & 6 deletions vizro-ai/tests/unit/vizro-ai/plot/test_response_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,48 @@ def test_check_chart_code_invalid(self, chart_code, error_type, error_message):
class TestChartPlanFactory:
"""Tests for the ChartPlanFactory class, mainly the execution of the chart code."""

def test_execute_chart_code_valid(self):
chart_plan_dynamic = ChartPlanFactory(data_frame=px.data.iris())
chart_plan_dynamic_valid = chart_plan_dynamic(
chart_type="Bubble Chart",
imports=["import plotly.express as px", "import numpy as np", "import random"],
chart_code="""def custom_chart(data_frame):
@pytest.mark.parametrize(
"chart_code",
[
"""def custom_chart(data_frame):
random_other_module_import = np.arange(10)
other_random_module_import = random.sample(range(10), 10)
fig = px.scatter(data_frame, x='sepal_width', y='petal_width')
return fig
""",
"""def custom_chart(data_frame):
fig = px.scatter(data_frame, x='sepal_length', y='petal_length')
return fig
""",
"""```python
def custom_chart(data_frame):
random_other_module_import = np.arange(10)
other_random_module_import = random.sample(range(10), 10)
fig = px.scatter(data_frame, x='sepal_width', y='petal_width')
return fig
```""",
"""```py
def custom_chart(data_frame):
random_other_module_import = np.arange(10)
other_random_module_import = random.sample(range(10), 10)
fig = px.scatter(data_frame, x='sepal_width', y='petal_width')
return fig
```""",
"""```
def custom_chart(data_frame):
random_other_module_import = np.arange(10)
other_random_module_import = random.sample(range(10), 10)
fig = px.scatter(data_frame, x='sepal_width', y='petal_width')
return fig
```""",
],
)
def test_execute_chart_code_valid(self, chart_code):
chart_plan_dynamic = ChartPlanFactory(data_frame=px.data.iris())
chart_plan_dynamic_valid = chart_plan_dynamic(
chart_type="Bubble Chart",
imports=["import plotly.express as px", "import numpy as np", "import random"],
chart_code=chart_code,
chart_insights="Very good insights",
code_explanation="Very good explanation",
)
Expand Down