diff --git a/vizro-ai/changelog.d/20240809_143449_lingyi_zhang_improve_code_style.md b/vizro-ai/changelog.d/20240809_143449_lingyi_zhang_improve_code_style.md new file mode 100644 index 000000000..f1f65e73c --- /dev/null +++ b/vizro-ai/changelog.d/20240809_143449_lingyi_zhang_improve_code_style.md @@ -0,0 +1,48 @@ + + + + + + + + + diff --git a/vizro-ai/src/vizro_ai/dashboard/_response_models/components.py b/vizro-ai/src/vizro_ai/dashboard/_response_models/components.py index 8814a515a..87f1494b3 100644 --- a/vizro-ai/src/vizro_ai/dashboard/_response_models/components.py +++ b/vizro-ai/src/vizro_ai/dashboard/_response_models/components.py @@ -34,6 +34,8 @@ class ComponentPlan(BaseModel): component_id: str = Field( pattern=r"^[a-z]+(_[a-z]+)?$", description="Small snake case description of this component." ) + # TODO: for improvement, we could dynamically create the pydantic model at runtime so that we can + # validate the df_name against the available dataframes df_name: str = Field( ..., description=""" diff --git a/vizro-ai/src/vizro_ai/dashboard/_response_models/controls.py b/vizro-ai/src/vizro_ai/dashboard/_response_models/controls.py index 43bdcf62f..c86a21c62 100644 --- a/vizro-ai/src/vizro_ai/dashboard/_response_models/controls.py +++ b/vizro-ai/src/vizro_ai/dashboard/_response_models/controls.py @@ -1,7 +1,7 @@ """Controls plan model.""" import logging -from typing import List, Optional +from typing import Any, Dict, List, Optional, Type import pandas as pd import vizro.models as vm @@ -16,76 +16,68 @@ logger = logging.getLogger(__name__) -def _create_filter_proxy(df_cols, df_schema, controllable_components) -> BaseModel: - """Create a filter proxy model.""" +class FilterProxyModel: + """Filter proxy model.""" - def validate_targets(v): - """Validate the targets.""" - if v not in controllable_components: - raise ValueError(f"targets must be one of {controllable_components}") - return v + @classmethod + def _create_model( + cls, df_cols: List[str], df_schema: Dict[str, Any], controllable_components: List[str] + ) -> Type[BaseModel]: + def validate_targets(v): + if v not in controllable_components: + raise ValueError(f"targets must be one of {controllable_components}") + return v - def validate_targets_not_empty(v): - """Validate the targets not empty.""" - if not controllable_components: - raise ValueError( - """ - This might be due to the filter target is not found in the controllable components. - returning default values. - """ - ) - return v - - def validate_column(v): - """Validate the column.""" - if v not in df_cols: - raise ValueError(f"column must be one of {df_cols}") - return v - - @root_validator(allow_reuse=True) - def validate_date_picker_column(cls, values): - """Validate the column for date picker.""" - column = values.get("column") - selector = values.get("selector") - if selector and selector.type == "date_picker": - if not pd.api.types.is_datetime64_any_dtype(df_schema[column]): + def validate_targets_not_empty(v): + if not controllable_components: raise ValueError( - f""" - The column '{column}' is not of datetime type. Selector type 'date_picker' is - not allowed. Use 'dropdown' instead. + """ + This might be due to the filter target is not found in the controllable components. + returning default values. """ ) - return values - - return create_model( - "FilterProxy", - targets=( - List[str], - Field( - ..., - description=f""" - Target component to be affected by filter. - Must be one of {controllable_components}. ALWAYS REQUIRED. - """, + return v + + def validate_column(v): + if v not in df_cols: + raise ValueError(f"column must be one of {df_cols}") + return v + + @root_validator(allow_reuse=True) + def validate_date_picker_column(cls, values): + column = values.get("column") + selector = values.get("selector") + if selector and selector.type == "date_picker": + if not pd.api.types.is_datetime64_any_dtype(df_schema[column]): + raise ValueError( + f""" + The column '{column}' is not of datetime type. Selector type 'date_picker' is + not allowed. Use 'dropdown' instead. + """ + ) + return values + + return create_model( + "FilterProxy", + targets=( + List[str], + Field( + ..., + description=f""" + Target component to be affected by filter. + Must be one of {controllable_components}. ALWAYS REQUIRED. + """, + ), ), - ), - column=(str, Field(..., description="Column name of DataFrame to filter. ALWAYS REQUIRED.")), - __validators__={ - "validator1": validator("targets", pre=True, each_item=True, allow_reuse=True)(validate_targets), - "validator2": validator("column", allow_reuse=True)(validate_column), - "validator3": validator("targets", pre=True, allow_reuse=True)(validate_targets_not_empty), - "validator4": validate_date_picker_column, - }, - __base__=vm.Filter, - ) - - -def _create_filter(filter_prompt, model, df_cols, df_schema, controllable_components) -> vm.Filter: - result_proxy = _create_filter_proxy( - df_cols=df_cols, df_schema=df_schema, controllable_components=controllable_components - ) - proxy = _get_pydantic_model(query=filter_prompt, llm_model=model, response_model=result_proxy, df_info=df_schema) - return vm.Filter.parse_obj(proxy.dict(exclude_unset=True)) + column=(str, Field(..., description="Column name of DataFrame to filter. ALWAYS REQUIRED.")), + __validators__={ + "validator1": validator("targets", pre=True, each_item=True, allow_reuse=True)(validate_targets), + "validator2": validator("column", allow_reuse=True)(validate_column), + "validator3": validator("targets", pre=True, allow_reuse=True)(validate_targets_not_empty), + "validator4": validate_date_picker_column, + }, + __base__=vm.Filter, + ) class ControlPlan(BaseModel): @@ -100,31 +92,58 @@ class ControlPlan(BaseModel): to control a specific component, include the relevant component details. """, ) - df_name: str = Field( + target_components_id: List[str] = Field( ..., description=""" - The name of the dataframe that the target component will use. - If the dataframe is not used, please specify that. + The id of the target components that this control will affect. """, ) - def create(self, model, controllable_components, all_df_metadata) -> Optional[vm.Filter]: + def _get_target_df_name(self, components_plan, controllable_components): + target_controllable = set(self.target_components_id) & set(controllable_components) + df_names = { + component_plan.df_name + for component_plan in components_plan + if component_plan.component_id in target_controllable + } + + if len(df_names) > 1: + logger.warning( + f""" +[FALLBACK] Multiple dataframes found in the target components: {df_names}. +Choose one dataframe to build the filter. +""" + ) + + return next(iter(df_names)) if df_names else None + + def _create_filter(self, filter_prompt, model, df_cols, df_schema, controllable_components) -> vm.Filter: + FilterProxy = FilterProxyModel._create_model( + df_cols=df_cols, df_schema=df_schema, controllable_components=controllable_components + ) + proxy = _get_pydantic_model(query=filter_prompt, llm_model=model, response_model=FilterProxy, df_info=df_schema) + return vm.Filter.parse_obj(proxy.dict(exclude_unset=True)) + + def create(self, model, controllable_components, all_df_metadata, components_plan) -> Optional[vm.Filter]: """Create the control.""" filter_prompt = f""" Create a filter from the following instructions: <{self.control_description}>. Do not make up things that are optional and DO NOT configure actions, action triggers or action chains. If no options are specified, leave them out. """ + + df_name = self._get_target_df_name(components_plan, controllable_components) + try: - _df_schema = all_df_metadata.get_df_schema(self.df_name) + _df_schema = all_df_metadata.get_df_schema(df_name) _df_cols = list(_df_schema.keys()) except KeyError: - logger.warning(f"Dataframe {self.df_name} not found in metadata, returning default values.") + logger.warning(f"Dataframe {df_name} not found in metadata, returning default values.") return None try: if self.control_type == "Filter": - res = _create_filter( + res = self._create_filter( filter_prompt=filter_prompt, model=model, df_cols=_df_cols, @@ -147,7 +166,9 @@ def create(self, model, controllable_components, all_df_metadata) -> Optional[vm if __name__ == "__main__": import pandas as pd from dotenv import load_dotenv + from vizro.tables import dash_ag_grid from vizro_ai._llm_models import _get_llm_model + from vizro_ai.dashboard._response_models.components import ComponentPlan from vizro_ai.dashboard.utils import AllDfMetadata, DfMetadata load_dotenv() @@ -155,16 +176,29 @@ def create(self, model, controllable_components, all_df_metadata) -> Optional[vm model = _get_llm_model() all_df_metadata = AllDfMetadata({}) - all_df_metadata.all_df_metadata["gdp_chart"] = DfMetadata( + all_df_metadata.all_df_metadata["world_gdp"] = DfMetadata( df_schema={"a": "int64", "b": "int64"}, df=pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [4, 5, 6, 7, 8]}), df_sample=pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [4, 5, 6, 7, 8]}), ) + components_plan = [ + ComponentPlan( + component_type="AgGrid", + component_description="Create a table that shows GDP data.", + component_id="gdp_table", + df_name="world_gdp", + ) + ] + vm.AgGrid(id="gdp_table", figure=dash_ag_grid(data_frame="world_gdp")) control_plan = ControlPlan( control_type="Filter", control_description="Create a filter that filters the data by column 'a'.", - df_name="gdp_chart", + target_components_id=["gdp_table"], ) control = control_plan.create( - model, ["gdp_chart"], all_df_metadata - ) # error: Target gdp_chart not found in model_manager. + model, + ["gdp_table"], + all_df_metadata, + components_plan, + ) + print(control.__repr__()) # noqa: T201 diff --git a/vizro-ai/src/vizro_ai/dashboard/_response_models/layout.py b/vizro-ai/src/vizro_ai/dashboard/_response_models/layout.py index 7ecd33d42..6516285ca 100644 --- a/vizro-ai/src/vizro_ai/dashboard/_response_models/layout.py +++ b/vizro-ai/src/vizro_ai/dashboard/_response_models/layout.py @@ -13,32 +13,6 @@ logger = logging.getLogger(__name__) -def _convert_to_grid(layout_grid_template_areas: List[str], component_ids: List[str]) -> List[List[int]]: - component_map = {component: index for index, component in enumerate(component_ids)} - grid = [] - - for row in layout_grid_template_areas: - grid_row = [] - for raw_cell in row.split(): - cell = raw_cell.strip("'\"") - if cell == ".": - grid_row.append(-1) - else: - try: - grid_row.append(component_map[cell]) - except KeyError: - logger.warning( - f""" -[FALLBACK] Component {cell} not found in component_ids: {component_ids}. -Returning default values. -""" - ) - return [] - grid.append(grid_row) - - return grid - - class LayoutPlan(BaseModel): """Layout plan model, which only applies to Vizro Components(Graph, AgGrid, Card).""" @@ -56,15 +30,38 @@ class LayoutPlan(BaseModel): """, ) + def _convert_to_grid(self, component_ids: List[str]) -> List[List[int]]: + component_map = {component: index for index, component in enumerate(component_ids)} + grid = [] + + for row in self.layout_grid_template_areas: + grid_row = [] + for raw_cell in row.split(): + cell = raw_cell.strip("'\"") + if cell == ".": + grid_row.append(-1) + else: + try: + grid_row.append(component_map[cell]) + except KeyError: + logger.warning( + f""" +[FALLBACK] Component {cell} not found in component_ids: {component_ids}. +Returning default values. +""" + ) + return [] + grid.append(grid_row) + + return grid + def create(self, component_ids: List[str]) -> Optional[vm.Layout]: """Create the layout.""" if not self.layout_grid_template_areas: return None try: - grid = _convert_to_grid( - layout_grid_template_areas=self.layout_grid_template_areas, component_ids=component_ids - ) + grid = self._convert_to_grid(component_ids=component_ids) actual = vm.Layout(grid=grid) except ValidationError as e: logger.warning( diff --git a/vizro-ai/src/vizro_ai/dashboard/_response_models/page.py b/vizro-ai/src/vizro_ai/dashboard/_response_models/page.py index b1504c7f2..4efde4302 100644 --- a/vizro-ai/src/vizro_ai/dashboard/_response_models/page.py +++ b/vizro-ai/src/vizro_ai/dashboard/_response_models/page.py @@ -155,6 +155,7 @@ def _build_controls(self, model, all_df_metadata): model=model, controllable_components=self._controllable_components(model=model, all_df_metadata=all_df_metadata), all_df_metadata=all_df_metadata, + components_plan=self.components_plan, ) if control: controls.append(control)