From f7fe2f8b8bb9214e4094eb3c25e8ad9d6ec4b201 Mon Sep 17 00:00:00 2001 From: Deepyaman Datta Date: Thu, 26 Dec 2024 08:52:56 -0500 Subject: [PATCH 1/2] fix(api): ensure memtable schema and columns match --- ibis/expr/api.py | 6 ++++++ ibis/expr/tests/test_api.py | 9 +++++++++ 2 files changed, 15 insertions(+) diff --git a/ibis/expr/api.py b/ibis/expr/api.py index 5c726715b447..147c4a702a15 100644 --- a/ibis/expr/api.py +++ b/ibis/expr/api.py @@ -408,6 +408,12 @@ def memtable( "passing `columns` and schema` is ambiguous; " "pass one or the other but not both" ) + + if schema is not None: + import ibis + + schema = ibis.schema(schema) + return _memtable(data, name=name, schema=schema, columns=columns) diff --git a/ibis/expr/tests/test_api.py b/ibis/expr/tests/test_api.py index 091ab3a259d6..768f6a0eae3c 100644 --- a/ibis/expr/tests/test_api.py +++ b/ibis/expr/tests/test_api.py @@ -143,6 +143,15 @@ def test_duplicate_columns_in_memtable_not_allowed(): ibis.memtable(df) +def test_memtable_column_names_match_schema(): + pd = pytest.importorskip("pandas") + + df = pd.DataFrame([[1, 2], [3, 4]]) + schema = {"a": "int64", "b": "int64"} + t = ibis.memtable(df, schema=schema) + t.a.execute() # Raises a KeyError if the column name does not match + + @pytest.mark.parametrize( "op", [ From 5f5ab22a4b4c57b6a8238d4c671bf0b781199ffe Mon Sep 17 00:00:00 2001 From: Deepyaman Datta Date: Thu, 26 Dec 2024 17:28:35 -0500 Subject: [PATCH 2/2] test(api): update test to avoid executing anything --- ibis/expr/tests/test_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ibis/expr/tests/test_api.py b/ibis/expr/tests/test_api.py index 768f6a0eae3c..f58e58471614 100644 --- a/ibis/expr/tests/test_api.py +++ b/ibis/expr/tests/test_api.py @@ -149,7 +149,7 @@ def test_memtable_column_names_match_schema(): df = pd.DataFrame([[1, 2], [3, 4]]) schema = {"a": "int64", "b": "int64"} t = ibis.memtable(df, schema=schema) - t.a.execute() # Raises a KeyError if the column name does not match + assert t.op().data.to_frame().columns.tolist() == ["a", "b"] @pytest.mark.parametrize(