diff --git a/orca_test/orca_test.py b/orca_test/orca_test.py index 1caf605..d6b57ae 100644 --- a/orca_test/orca_test.py +++ b/orca_test/orca_test.py @@ -289,7 +289,7 @@ def assert_column_is_registered(table_name, column_name): assert_table_can_be_generated(table_name) t = orca.get_table(table_name) - if (column_name not in t.columns) and (column_name != t.index.name): + if (column_name not in t.columns) and (column_name not in t.index.names): msg = "Column '%s' is not registered in table '%s'" % (column_name, table_name) raise OrcaAssertionError(msg) return @@ -311,7 +311,7 @@ def assert_column_not_registered(table_name, column_name): assert_table_can_be_generated(table_name) t = orca.get_table(table_name) - if (column_name in t.columns) or (column_name == t.index.name): + if (column_name in t.columns) or (column_name in t.index.names): msg = "Column '%s' is already registered in table '%s'" % (column_name, table_name) raise OrcaAssertionError(msg) return @@ -340,7 +340,7 @@ def assert_column_can_be_generated(table_name, column_name): t = orca.get_table(table_name) # t.column_type() fails for index columns, so we have to check for them separately - if column_name == t.index.name: + if column_name in t.index.names: return elif t.column_type(column_name) == 'function': @@ -372,6 +372,10 @@ def assert_column_is_primary_key(table_name, column_name): assert_column_can_be_generated(table_name, column_name) idx = orca.get_table(table_name).index + if len(idx.names) > 1: + msg = "The table '%s' has a multi-index, and primary key checks are not yet supported." \ + % table_name + raise OrcaAssertionError(msg) if idx.name != column_name: msg = "Column '%s' is not set as the index of table '%s'" \ % (column_name, table_name) @@ -448,9 +452,8 @@ def get_column_or_index(table_name, column_name): assert_column_can_be_generated(table_name, column_name) t = orca.get_table(table_name) - if column_name == t.index.name: - return pd.Series(t.index) - + if column_name in t.index.names: + return t.to_frame([]).reset_index()[column_name] else: return t.get_column(column_name)