Skip to content

Commit

Permalink
[Feat] Add xAI grok-beta to code (#858)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
lingyielia and pre-commit-ci[bot] authored Nov 7, 2024
1 parent 7596cd9 commit 1dee965
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 4 deletions.
48 changes: 48 additions & 0 deletions vizro-ai/changelog.d/20241107_112343_lingyi_zhang_xai.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
<!--
A new scriv changelog fragment.
Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨
- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Removed
- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Added
- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Changed
- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Deprecated
- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Fixed
- A bullet item for the Fixed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Security
- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
10 changes: 9 additions & 1 deletion vizro-ai/examples/dashboard_ui/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO) # TODO: remove manual setting and make centrally controlled

SUPPORTED_VENDORS = {"OpenAI": ChatOpenAI, "Anthropic": ChatAnthropic, "Mistral": ChatMistralAI}
SUPPORTED_VENDORS = {
"OpenAI": ChatOpenAI,
"Anthropic": ChatAnthropic,
"Mistral": ChatMistralAI,
"xAI": ChatOpenAI,
}

SUPPORTED_MODELS = {
"OpenAI": [
Expand All @@ -43,6 +48,7 @@
"claude-3-haiku-20240307",
],
"Mistral": ["mistral-large-latest", "open-mistral-nemo", "codestral-latest"],
"xAI": ["grok-beta"],
}
DEFAULT_TEMPERATURE = 0.1
DEFAULT_RETRY = 3
Expand All @@ -62,6 +68,8 @@ def get_vizro_ai_plot(user_prompt, df, model, api_key, api_base, vendor_input):
)
if vendor_input == "Mistral":
llm = vendor(model=model, mistral_api_key=api_key, mistral_api_url=api_base, temperature=DEFAULT_TEMPERATURE)
if vendor_input == "xAI":
llm = vendor(model=model, openai_api_key=api_key, openai_api_base=api_base, temperature=DEFAULT_TEMPERATURE)

vizro_ai = VizroAI(model=llm)
ai_outputs = vizro_ai.plot(df, user_prompt, max_debug_retry=DEFAULT_RETRY, return_elements=True)
Expand Down
7 changes: 6 additions & 1 deletion vizro-ai/examples/dashboard_ui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
"claude-3-haiku-20240307",
],
"Mistral": ["mistral-large-latest", "open-mistral-nemo", "codestral-latest"],
"xAI": ["grok-beta"],
}


Expand Down Expand Up @@ -180,7 +181,11 @@
MyDropdown(
options=SUPPORTED_MODELS["OpenAI"], value="gpt-4o-mini", multi=False, id="model-dropdown-id"
),
OffCanvas(id="settings", options=["OpenAI", "Anthropic", "Mistral"], value="OpenAI"),
OffCanvas(
id="settings",
options=["OpenAI", "Anthropic", "Mistral", "xAI"],
value="OpenAI",
),
UserPromptTextArea(id="text-area-id"),
# Modal(id="modal"),
],
Expand Down
10 changes: 10 additions & 0 deletions vizro-ai/examples/dashboard_ui/assets/custom_css.css
Original file line number Diff line number Diff line change
Expand Up @@ -306,3 +306,13 @@
#open-settings-id:hover {
cursor: pointer;
}

.hover-effect {
transition: all 0.2s ease !important;
}

.hover-effect:hover {
background-color: rgba(255, 255, 255, 0.1) !important;
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
transform: translateY(-2px);
}
59 changes: 58 additions & 1 deletion vizro-ai/examples/dashboard_ui/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,33 @@ def build(self):
)


def create_provider_item(name, url, note=None):
"""Helper function to create a consistent ListGroupItem for each provider."""
return dbc.ListGroupItem(
[
html.Div(
[
html.Span(name, style={"color": "#ffffff"}),
(html.Small(note, style={"color": "rgba(255, 255, 255, 0.5)"}) if note else None),
html.Span("→", className="float-end", style={"color": "#ffffff"}),
],
className="d-flex justify-content-between align-items-center",
)
],
href=url,
target="_blank",
action=True,
style={
"background-color": "transparent",
"border": "1px solid rgba(255, 255, 255, 0.1)",
"margin-bottom": "8px",
"transition": "all 0.2s ease",
"cursor": "pointer",
},
class_name="list-group-item-action hover-effect",
)


class OffCanvas(vm.VizroBaseModel):
"""OffCanvas component for settings."""

Expand Down Expand Up @@ -202,14 +229,44 @@ def build(self):
className="mb-3",
)

providers = [
{"name": "OpenAI", "url": "https://openai.com/index/openai-api/"},
{"name": "Anthropic", "url": "https://docs.anthropic.com/en/api/getting-started"},
{"name": "Mistral", "url": "https://docs.mistral.ai/getting-started/quickstart/"},
{"name": "xAI", "url": "https://x.ai/blog/api", "note": "(Free API credits available)"},
]

api_instructions = html.Div(
[
html.Hr(
style={
"margin": "2rem 0",
"border-color": "rgba(255, 255, 255, 0.1)",
"border-style": "solid",
"border-width": "0 0 1px 0",
}
),
html.Div("Get API Keys", className="mb-3", style={"color": "#ffffff"}),
dbc.ListGroup(
[
create_provider_item(name=provider["name"], url=provider["url"], note=provider.get("note"))
for provider in providers
],
flush=True,
className="border-0",
),
],
)

offcanvas = dbc.Offcanvas(
id=self.id,
children=[
html.Div(
children=[
input_groups,
api_instructions,
]
)
),
],
title="Settings",
is_open=True,
Expand Down
8 changes: 8 additions & 0 deletions vizro-ai/examples/example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@
"# llm = \"claude-3-5-sonnet-latest\"\n",
"# llm = \"mistral-large-latest\"\n",
"\n",
"# llm = \"grok-beta\" #xAI API is compatible with OpenAI. To use grok-beta,\n",
"# point `OPENAI_BASE_URL` to the xAI baseurl, use xAI API key for `OPENAI_API_KEY`\n",
"# when setting up the environment variables\n",
"# e.g.\n",
"# OPENAI_BASE_URL=\"https://api.x.ai/v1\"\n",
"# OPENAI_API_KEY=<xAI API key>\n",
"# reference: https://docs.x.ai/api/integrations#openai-sdk\n",
"\n",
"# from langchain_openai import ChatOpenAI\n",
"# llm = ChatOpenAI(\n",
"# model=\"gpt-4o\")\n",
Expand Down
2 changes: 2 additions & 0 deletions vizro-ai/src/vizro_ai/_llm_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@
"claude-3-haiku-20240307",
],
"Mistral": ["mistral-large-latest", "open-mistral-nemo", "codestral-latest"],
"xAI": ["grok-beta"],
}

DEFAULT_WRAPPER_MAP: dict[str, BaseChatModel] = {
"OpenAI": ChatOpenAI,
"Anthropic": ChatAnthropic,
"Mistral": ChatMistralAI,
"xAI": ChatOpenAI, # xAI API is compatible with OpenAI
}

DEFAULT_MODEL = "gpt-4o-mini"
Expand Down
9 changes: 8 additions & 1 deletion vizro-ai/src/vizro_ai/plot/_response_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,18 @@ class ChartPlan(BaseModel):

@validator("chart_code")
def _check_chart_code(cls, v):
# Remove markdown code block if present
if v.startswith("```python\n") and v.endswith("```"):
v = v[len("```python\n") : -3].strip()
elif v.startswith("```\n") and v.endswith("```"):
v = v[len("```\n") : -3].strip()

# TODO: add more checks: ends with return, has return, no second function def, only one indented line
if f"def {CUSTOM_CHART_NAME}(" not in v:
raise ValueError(f"The chart code must be wrapped in a function named `{CUSTOM_CHART_NAME}`")

if "data_frame" not in v.split("\n")[0]:
first_line = v.split("\n")[0].strip()
if "data_frame" not in first_line:
raise ValueError(
"""The chart code must accept a single argument `data_frame`,
and it should be the first argument of the chart."""
Expand Down

0 comments on commit 1dee965

Please sign in to comment.