Skip to content

Commit

Permalink
[Tidy] Improve markdown strip and add tests (#861)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
maxschulz-COL and pre-commit-ci[bot] authored Nov 11, 2024
1 parent 91b67f1 commit 02e6419
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 31 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build-vizro-whl.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ on:
- main
paths:
- "vizro-core/**"
- "vizro-ai/examples/**"

defaults:
run:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
<!--
A new scriv changelog fragment.
Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨
- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Removed
- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Added
- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Changed
- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Deprecated
- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Fixed
- A bullet item for the Fixed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Security
- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
2 changes: 1 addition & 1 deletion vizro-ai/examples/dashboard_ui/requirements.in
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
gunicorn
vizro-ai>=0.3.0
vizro-ai>=0.3.2
black
openpyxl
langchain_anthropic
Expand Down
21 changes: 2 additions & 19 deletions vizro-ai/examples/dashboard_ui/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@ anyio==4.4.0
# anthropic
# httpx
# openai
async-timeout==4.0.3
# via
# aiohttp
# langchain
attrs==24.2.0
# via aiohttp
autoflake==2.3.1
Expand Down Expand Up @@ -67,8 +63,6 @@ distro==1.9.0
# openai
et-xmlfile==1.1.0
# via openpyxl
exceptiongroup==1.2.2
# via anyio
filelock==3.16.1
# via huggingface-hub
flask==3.0.3
Expand All @@ -83,8 +77,6 @@ frozenlist==1.4.1
# aiosignal
fsspec==2024.10.0
# via huggingface-hub
greenlet==3.1.0
# via sqlalchemy
gunicorn==23.0.0
# via -r requirements.in
h11==0.14.0
Expand All @@ -108,9 +100,7 @@ idna==3.8
# requests
# yarl
importlib-metadata==8.5.0
# via
# dash
# flask
# via dash
itsdangerous==2.2.0
# via flask
jinja2==3.1.4
Expand Down Expand Up @@ -256,23 +246,16 @@ tokenizers==0.20.1
# via
# anthropic
# langchain-mistralai
tomli==2.0.1
# via
# autoflake
# black
tqdm==4.66.5
# via
# huggingface-hub
# openai
typing-extensions==4.12.2
# via
# anthropic
# anyio
# black
# dash
# huggingface-hub
# langchain-core
# multidict
# openai
# pydantic
# pydantic-core
Expand All @@ -283,7 +266,7 @@ urllib3==2.2.3
# via requests
vizro==0.1.23
# via vizro-ai
vizro-ai==0.3.0
vizro-ai==0.3.2
# via -r requirements.in
werkzeug==3.0.4
# via
Expand Down
18 changes: 13 additions & 5 deletions vizro-ai/src/vizro_ai/plot/_response_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,18 @@
CUSTOM_CHART_NAME = "custom_chart"


def _strip_markdown(code_string: str) -> str:
"""Strip markdown code block from the code string."""
prefixes = ["```python\n", "```py\n", "```\n"]
for prefix in prefixes:
if code_string.startswith(prefix):
code_string = code_string[len(prefix) :]
break
if code_string.endswith("```"):
code_string = code_string[: -len("```")]
return code_string.strip()


def _format_and_lint(code_string: str) -> str:
# Tracking https://github.com/astral-sh/ruff/issues/659 for proper Python API
# Good example: https://github.com/astral-sh/ruff/issues/8401#issuecomment-1788806462
Expand Down Expand Up @@ -93,11 +105,7 @@ class ChartPlan(BaseModel):

@validator("chart_code")
def _check_chart_code(cls, v):
# Remove markdown code block if present
if v.startswith("```python\n") and v.endswith("```"):
v = v[len("```python\n") : -3].strip()
elif v.startswith("```\n") and v.endswith("```"):
v = v[len("```\n") : -3].strip()
v = _strip_markdown(v)

# TODO: add more checks: ends with return, has return, no second function def, only one indented line
if f"def {CUSTOM_CHART_NAME}(" not in v:
Expand Down
43 changes: 37 additions & 6 deletions vizro-ai/tests/unit/vizro-ai/plot/test_response_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,48 @@ def test_check_chart_code_invalid(self, chart_code, error_type, error_message):
class TestChartPlanFactory:
"""Tests for the ChartPlanFactory class, mainly the execution of the chart code."""

def test_execute_chart_code_valid(self):
chart_plan_dynamic = ChartPlanFactory(data_frame=px.data.iris())
chart_plan_dynamic_valid = chart_plan_dynamic(
chart_type="Bubble Chart",
imports=["import plotly.express as px", "import numpy as np", "import random"],
chart_code="""def custom_chart(data_frame):
@pytest.mark.parametrize(
"chart_code",
[
"""def custom_chart(data_frame):
random_other_module_import = np.arange(10)
other_random_module_import = random.sample(range(10), 10)
fig = px.scatter(data_frame, x='sepal_width', y='petal_width')
return fig
""",
"""def custom_chart(data_frame):
fig = px.scatter(data_frame, x='sepal_length', y='petal_length')
return fig
""",
"""```python
def custom_chart(data_frame):
random_other_module_import = np.arange(10)
other_random_module_import = random.sample(range(10), 10)
fig = px.scatter(data_frame, x='sepal_width', y='petal_width')
return fig
```""",
"""```py
def custom_chart(data_frame):
random_other_module_import = np.arange(10)
other_random_module_import = random.sample(range(10), 10)
fig = px.scatter(data_frame, x='sepal_width', y='petal_width')
return fig
```""",
"""```
def custom_chart(data_frame):
random_other_module_import = np.arange(10)
other_random_module_import = random.sample(range(10), 10)
fig = px.scatter(data_frame, x='sepal_width', y='petal_width')
return fig
```""",
],
)
def test_execute_chart_code_valid(self, chart_code):
chart_plan_dynamic = ChartPlanFactory(data_frame=px.data.iris())
chart_plan_dynamic_valid = chart_plan_dynamic(
chart_type="Bubble Chart",
imports=["import plotly.express as px", "import numpy as np", "import random"],
chart_code=chart_code,
chart_insights="Very good insights",
code_explanation="Very good explanation",
)
Expand Down

0 comments on commit 02e6419

Please sign in to comment.