diff --git a/openml/datasets/dataset.py b/openml/datasets/dataset.py
index 4acd688f4..b00c458e3 100644
--- a/openml/datasets/dataset.py
+++ b/openml/datasets/dataset.py
@@ -329,13 +329,28 @@ def __eq__(self, other: Any) -> bool:
             "version",
             "upload_date",
             "url",
+            "_parquet_url",
             "dataset",
             "data_file",
+            "format",
+            "cache_format",
+        }
+
+        # Excluded from the comparison in addition to server_fields. By their
+        # names these hold local cache state rather than dataset content.
+        cache_fields = {
+            "_dataset",
+            "data_file",
+            "data_pickle_file",
+            "data_feather_file",
+            "feather_attribute_file",
+            "parquet_file",
         }
 
         # check that common keys and values are identical
-        self_keys = set(self.__dict__.keys()) - server_fields
-        other_keys = set(other.__dict__.keys()) - server_fields
+        ignore_fields = server_fields | cache_fields
+        self_keys = set(self.__dict__.keys()) - ignore_fields
+        other_keys = set(other.__dict__.keys()) - ignore_fields
         return self_keys == other_keys and all(
             self.__dict__[key] == other.__dict__[key] for key in self_keys
         )
diff --git a/tests/test_datasets/test_dataset.py b/tests/test_datasets/test_dataset.py
index 80da9c842..4598b8985 100644
--- a/tests/test_datasets/test_dataset.py
+++ b/tests/test_datasets/test_dataset.py
@@ -309,6 +309,12 @@ def test_lazy_loading_metadata(self):
         assert _dataset.features == _compare_dataset.features
         assert _dataset.qualities == _compare_dataset.qualities
 
+    def test_equality_comparison(self):
+        """A dataset equals itself but not another dataset or a non-dataset."""
+        assert self.iris == self.iris
+        assert self.iris != self.titanic
+        assert self.titanic != "Wrong_object"
+
 
 class OpenMLDatasetTestOnTestServer(TestBase):
     def setUp(self):