diff --git a/openbb_platform/core/integration/test_obbject.py b/openbb_platform/core/integration/test_obbject.py index 735dc5209696..7db3e4bc4769 100644 --- a/openbb_platform/core/integration/test_obbject.py +++ b/openbb_platform/core/integration/test_obbject.py @@ -89,3 +89,13 @@ def test_show(obb): stocks_data = obb.equity.price.historical("AAPL", provider="fmp", chart=True) assert isinstance(stocks_data.chart.fig, OpenBBFigure) assert stocks_data.chart.fig.show() is None + + +@pytest.mark.integration +def test_get_field_descriptions(obb): + """Test obbject get field descriptions.""" + + obb_data = obb.equity.profile("MSFT", provider="yfinance") + descriptions = obb_data.get_field_descriptions() + assert isinstance(descriptions, dict) + assert len(obb_data.to_df(index=None).columns) == len(descriptions) diff --git a/openbb_platform/core/openbb_core/app/model/obbject.py b/openbb_platform/core/openbb_core/app/model/obbject.py index 8d4ba26839c5..da8ea5d4a4d9 100644 --- a/openbb_platform/core/openbb_core/app/model/obbject.py +++ b/openbb_platform/core/openbb_core/app/model/obbject.py @@ -116,6 +116,37 @@ def model_parametrized_name(cls, params: Any) -> str: """Return the model name with the parameters.""" return f"OBBject[{cls.results_type_repr(params)}]" + def get_field_descriptions(self) -> Dict[str, str]: + """ + Get a dictionary of the returned field keys with their descriptions. + + Fields automatically created by `alias_generator` will not have descriptions. + """ + descriptions = {} + model = None + if isinstance(self.results, list): + model = self.results[0].model_json_schema(by_alias=False).get("properties", None) # type: ignore + columns = self.to_df(index=None).columns.to_list() # type: ignore + if columns[0] == 0: + columns = self.to_df(index=None).iloc[:, 0].to_list() + else: + model = self.results.model_json_schema(by_alias=False).get("properties", None) # type: ignore + columns = list(self.results.model_dump(exclude_none=True).keys()) # type: ignore + if model is None: + raise OpenBBError( + "Could not extract model property definitions from OBBject." + ) + for i in columns: + if i in model: + descriptions[i] = ( + str(model[i].get("description", None)) + .replace(" ", "") + .replace("\n", " ") + .strip() + ) + + return descriptions + def to_df( self, index: Optional[Union[str, None]] = "date", sort_by: Optional[str] = None ) -> pd.DataFrame: diff --git a/openbb_platform/core/tests/app/model/test_obbject.py b/openbb_platform/core/tests/app/model/test_obbject.py index 431b2b56280d..49ac8ae0bc2d 100644 --- a/openbb_platform/core/tests/app/model/test_obbject.py +++ b/openbb_platform/core/tests/app/model/test_obbject.py @@ -1,5 +1,6 @@ """Tests for the OBBject class.""" +from typing import Optional from unittest.mock import MagicMock import pandas as pd @@ -8,6 +9,7 @@ from openbb_core.app.utils import basemodel_to_df from openbb_core.provider.abstract.data import Data from pandas.testing import assert_frame_equal +from pydantic import Field def test_OBBject(): @@ -56,6 +58,22 @@ class MockDataFrame(Data): value: float +class MockDataModel(Data): + + name: str = Field(default=None, description="Common name of the company.") + cik: str = Field( + default=None, + description="Central Index Key assigned to the company.", + ) + cusip: str = Field(default=None, description="CUSIP identifier for the company.") + isin: str = Field( + default=None, description="International Securities Identification Number." + ) + lei: Optional[str] = Field( + default=None, description="Legal Entity Identifier assigned to the company." + ) + + @pytest.mark.parametrize( "results, expected_df", [ @@ -388,3 +406,48 @@ def test_show_chart_no_fig(): # Act and Assert with pytest.raises(OpenBBError, match="Chart not found."): mock_instance.show() + + +@pytest.mark.parametrize( + "results", + [ + # Test case 1: List of models. + ( + [ + MockDataModel( + name="Mock Company", + cik="0001234567", + cusip="5556789", + isin="US5556789", + lei=None, + ), + MockDataModel( + name="Mock Company 2", + cik="0001234568", + cusip="5556781", + isin="US5556788", + lei="1234567890", + ), + ] + ), + # Test case 2: Not a list. + MockDataModel( + name="Mock Company 3", + cik="0001234565", + cusip="5556783", + isin="US5556785", + lei="1234567891", + ), + ], +) +def test_get_field_descriptions(results): + """Test helper.""" + # Arrange + co = OBBject(results=results) + + # Act + descriptions = co.get_field_descriptions() + + # Assert + assert isinstance(descriptions, dict) + assert len(descriptions) == 5