diff --git a/vizro-ai/changelog.d/20241107_112343_lingyi_zhang_xai.md b/vizro-ai/changelog.d/20241107_112343_lingyi_zhang_xai.md new file mode 100644 index 000000000..f1f65e73c --- /dev/null +++ b/vizro-ai/changelog.d/20241107_112343_lingyi_zhang_xai.md @@ -0,0 +1,48 @@ + + + + + + + + + diff --git a/vizro-ai/examples/dashboard_ui/actions.py b/vizro-ai/examples/dashboard_ui/actions.py index a69893d59..872ce0274 100644 --- a/vizro-ai/examples/dashboard_ui/actions.py +++ b/vizro-ai/examples/dashboard_ui/actions.py @@ -28,7 +28,12 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) # TODO: remove manual setting and make centrally controlled -SUPPORTED_VENDORS = {"OpenAI": ChatOpenAI, "Anthropic": ChatAnthropic, "Mistral": ChatMistralAI} +SUPPORTED_VENDORS = { + "OpenAI": ChatOpenAI, + "Anthropic": ChatAnthropic, + "Mistral": ChatMistralAI, + "xAI": ChatOpenAI, +} SUPPORTED_MODELS = { "OpenAI": [ @@ -43,6 +48,7 @@ "claude-3-haiku-20240307", ], "Mistral": ["mistral-large-latest", "open-mistral-nemo", "codestral-latest"], + "xAI": ["grok-beta"], } DEFAULT_TEMPERATURE = 0.1 DEFAULT_RETRY = 3 @@ -62,6 +68,8 @@ def get_vizro_ai_plot(user_prompt, df, model, api_key, api_base, vendor_input): ) if vendor_input == "Mistral": llm = vendor(model=model, mistral_api_key=api_key, mistral_api_url=api_base, temperature=DEFAULT_TEMPERATURE) + if vendor_input == "xAI": + llm = vendor(model=model, openai_api_key=api_key, openai_api_base=api_base, temperature=DEFAULT_TEMPERATURE) vizro_ai = VizroAI(model=llm) ai_outputs = vizro_ai.plot(df, user_prompt, max_debug_retry=DEFAULT_RETRY, return_elements=True) diff --git a/vizro-ai/examples/dashboard_ui/app.py b/vizro-ai/examples/dashboard_ui/app.py index 2a4d00752..e8b29f7dc 100644 --- a/vizro-ai/examples/dashboard_ui/app.py +++ b/vizro-ai/examples/dashboard_ui/app.py @@ -70,6 +70,7 @@ "claude-3-haiku-20240307", ], "Mistral": ["mistral-large-latest", "open-mistral-nemo", "codestral-latest"], + "xAI": ["grok-beta"], } @@ -180,7 +181,11 @@ MyDropdown( options=SUPPORTED_MODELS["OpenAI"], value="gpt-4o-mini", multi=False, id="model-dropdown-id" ), - OffCanvas(id="settings", options=["OpenAI", "Anthropic", "Mistral"], value="OpenAI"), + OffCanvas( + id="settings", + options=["OpenAI", "Anthropic", "Mistral", "xAI"], + value="OpenAI", + ), UserPromptTextArea(id="text-area-id"), # Modal(id="modal"), ], diff --git a/vizro-ai/examples/dashboard_ui/assets/custom_css.css b/vizro-ai/examples/dashboard_ui/assets/custom_css.css index cd9c92d14..dea230d73 100644 --- a/vizro-ai/examples/dashboard_ui/assets/custom_css.css +++ b/vizro-ai/examples/dashboard_ui/assets/custom_css.css @@ -306,3 +306,13 @@ #open-settings-id:hover { cursor: pointer; } + +.hover-effect { + transition: all 0.2s ease !important; +} + +.hover-effect:hover { + background-color: rgba(255, 255, 255, 0.1) !important; + box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1); + transform: translateY(-2px); +} diff --git a/vizro-ai/examples/dashboard_ui/components.py b/vizro-ai/examples/dashboard_ui/components.py index 41a3d6d88..1e8397b3f 100644 --- a/vizro-ai/examples/dashboard_ui/components.py +++ b/vizro-ai/examples/dashboard_ui/components.py @@ -148,6 +148,33 @@ def build(self): ) +def create_provider_item(name, url, note=None): + """Helper function to create a consistent ListGroupItem for each provider.""" + return dbc.ListGroupItem( + [ + html.Div( + [ + html.Span(name, style={"color": "#ffffff"}), + (html.Small(note, style={"color": "rgba(255, 255, 255, 0.5)"}) if note else None), + html.Span("→", className="float-end", style={"color": "#ffffff"}), + ], + className="d-flex justify-content-between align-items-center", + ) + ], + href=url, + target="_blank", + action=True, + style={ + "background-color": "transparent", + "border": "1px solid rgba(255, 255, 255, 0.1)", + "margin-bottom": "8px", + "transition": "all 0.2s ease", + "cursor": "pointer", + }, + class_name="list-group-item-action hover-effect", + ) + + class OffCanvas(vm.VizroBaseModel): """OffCanvas component for settings.""" @@ -202,14 +229,44 @@ def build(self): className="mb-3", ) + providers = [ + {"name": "OpenAI", "url": "https://openai.com/index/openai-api/"}, + {"name": "Anthropic", "url": "https://docs.anthropic.com/en/api/getting-started"}, + {"name": "Mistral", "url": "https://docs.mistral.ai/getting-started/quickstart/"}, + {"name": "xAI", "url": "https://x.ai/blog/api", "note": "(Free API credits available)"}, + ] + + api_instructions = html.Div( + [ + html.Hr( + style={ + "margin": "2rem 0", + "border-color": "rgba(255, 255, 255, 0.1)", + "border-style": "solid", + "border-width": "0 0 1px 0", + } + ), + html.Div("Get API Keys", className="mb-3", style={"color": "#ffffff"}), + dbc.ListGroup( + [ + create_provider_item(name=provider["name"], url=provider["url"], note=provider.get("note")) + for provider in providers + ], + flush=True, + className="border-0", + ), + ], + ) + offcanvas = dbc.Offcanvas( id=self.id, children=[ html.Div( children=[ input_groups, + api_instructions, ] - ) + ), ], title="Settings", is_open=True, diff --git a/vizro-ai/examples/example.ipynb b/vizro-ai/examples/example.ipynb index 9fc56071c..f8029a472 100644 --- a/vizro-ai/examples/example.ipynb +++ b/vizro-ai/examples/example.ipynb @@ -14,6 +14,14 @@ "# llm = \"claude-3-5-sonnet-latest\"\n", "# llm = \"mistral-large-latest\"\n", "\n", + "# llm = \"grok-beta\" #xAI API is compatible with OpenAI. To use grok-beta,\n", + "# point `OPENAI_BASE_URL` to the xAI baseurl, use xAI API key for `OPENAI_API_KEY`\n", + "# when setting up the environment variables\n", + "# e.g.\n", + "# OPENAI_BASE_URL=\"https://api.x.ai/v1\"\n", + "# OPENAI_API_KEY=\n", + "# reference: https://docs.x.ai/api/integrations#openai-sdk\n", + "\n", "# from langchain_openai import ChatOpenAI\n", "# llm = ChatOpenAI(\n", "# model=\"gpt-4o\")\n", diff --git a/vizro-ai/src/vizro_ai/_llm_models.py b/vizro-ai/src/vizro_ai/_llm_models.py index c3c9858b0..6b572d512 100644 --- a/vizro-ai/src/vizro_ai/_llm_models.py +++ b/vizro-ai/src/vizro_ai/_llm_models.py @@ -27,12 +27,14 @@ "claude-3-haiku-20240307", ], "Mistral": ["mistral-large-latest", "open-mistral-nemo", "codestral-latest"], + "xAI": ["grok-beta"], } DEFAULT_WRAPPER_MAP: dict[str, BaseChatModel] = { "OpenAI": ChatOpenAI, "Anthropic": ChatAnthropic, "Mistral": ChatMistralAI, + "xAI": ChatOpenAI, # xAI API is compatible with OpenAI } DEFAULT_MODEL = "gpt-4o-mini" diff --git a/vizro-ai/src/vizro_ai/plot/_response_models.py b/vizro-ai/src/vizro_ai/plot/_response_models.py index 5bdedb73e..efa10259d 100644 --- a/vizro-ai/src/vizro_ai/plot/_response_models.py +++ b/vizro-ai/src/vizro_ai/plot/_response_models.py @@ -93,11 +93,18 @@ class ChartPlan(BaseModel): @validator("chart_code") def _check_chart_code(cls, v): + # Remove markdown code block if present + if v.startswith("```python\n") and v.endswith("```"): + v = v[len("```python\n") : -3].strip() + elif v.startswith("```\n") and v.endswith("```"): + v = v[len("```\n") : -3].strip() + # TODO: add more checks: ends with return, has return, no second function def, only one indented line if f"def {CUSTOM_CHART_NAME}(" not in v: raise ValueError(f"The chart code must be wrapped in a function named `{CUSTOM_CHART_NAME}`") - if "data_frame" not in v.split("\n")[0]: + first_line = v.split("\n")[0].strip() + if "data_frame" not in first_line: raise ValueError( """The chart code must accept a single argument `data_frame`, and it should be the first argument of the chart."""