diff --git a/greenplumpython/dataframe.py b/greenplumpython/dataframe.py
index e8be2dd1..2e8812dc 100644
--- a/greenplumpython/dataframe.py
+++ b/greenplumpython/dataframe.py
@@ -988,7 +988,7 @@ def save_as(
             """,
             has_results=False,
         )
-        return DataFrame.from_table(table_name, self._db, schema=schema)
+        return DataFrame.from_table(table_name, self._db, schema=schema if not temp else "pg_temp")
 
     def create_index(
         self,
@@ -1264,3 +1264,42 @@ def from_files(cls, files: list[str], parser: "NormalFunction", db: Database) ->
         raise NotImplementedError(
             "Please import greenplumpython.experimental.file to load the implementation."
         )
+
+    def describe(self) -> dict[str, str]:
+        """
+        Return the column names and data types of the table backing this dataframe.
+
+        The dataframe must be backed by a table in the database, e.g. after
+        :meth:`save_as`; column information is read from the pg_attribute catalog.
+
+        Returns:
+            Dictionary mapping each column name to its data type name, in the
+            order the columns are defined in the table.
+
+        Raises:
+            AssertionError: If the dataframe is not backed by a table or a
+                database, or if the table has no columns.
+
+        """
+        assert self._qualified_table_name is not None, "Dataframe is not saved in database."
+        assert self._db is not None
+        # The name is embedded in a SQL string literal, so double any single
+        # quote in it to keep the literal well-formed.
+        table_literal = str(self._qualified_table_name).replace("'", "''")
+        # Dropped columns remain in pg_attribute with attnum > 0, so exclude them
+        # explicitly; ORDER BY attnum returns columns in table definition order.
+        columns_query = f"""
+            SELECT attname AS column_name, atttypid::regtype AS data_type
+            FROM pg_attribute
+            WHERE attrelid = '{table_literal}'::regclass
+                AND attnum > 0 AND NOT attisdropped
+            ORDER BY attnum;
+        """
+        columns_inf_result = list(self._db._execute(columns_query, has_results=True))  # type: ignore reportUnknownVariableType
+        # A missing relation already fails in the query's ::regclass cast, so an
+        # empty result can only mean the table has no visible columns.
+        assert columns_inf_result, f"Table {self._qualified_table_name} has no columns."
+        columns_list: dict[str, str] = {
+            d["column_name"]: d["data_type"] for d in columns_inf_result  # type: ignore reportUnknownVariableType
+        }  # type: ignore reportUnknownVariableType
+        return columns_list
diff --git a/tests/test_dataframe.py b/tests/test_dataframe.py
index c7d8e8e3..72fc839b 100644
--- a/tests/test_dataframe.py
+++ b/tests/test_dataframe.py
@@ -475,6 +475,24 @@ def test_table_distributed_hash(db: gp.Database):
         assert row["distributedby"] == "DISTRIBUTED BY (id)"
 
 
+def test_table_describe(db: gp.Database):
+    columns = {"a": [1, 2, 3], "b": [1, 2, 3]}
+    t = db.create_dataframe(columns=columns)
+    df = t.save_as("const_table_describe", column_names=["a", "b"], schema="test")
+    result = df.describe()
+    assert list(result) == ["a", "b"]
+    # A dataframe derived by a query is not backed by a table.
+    df_s = df[["a", "b"]]
+    with pytest.raises(Exception) as exc_info:
+        df_s.describe()
+    assert "Dataframe is not saved in database" in str(exc_info.value)
+    # A missing table is reported by the database itself.
+    df_not_exist = db.create_dataframe(table_name="not_exist_table")
+    with pytest.raises(Exception) as exc_info:
+        df_not_exist.describe()
+    assert 'relation "not_exist_table" does not exist' in str(exc_info.value)
+
+
 import pandas as pd
 
 