From 71441344476ef59f9f88d34f81f99dde42ef8ff8 Mon Sep 17 00:00:00 2001 From: Agisilaos Kounelis <36283973+kounelisagis@users.noreply.github.com> Date: Tue, 11 Feb 2025 13:29:48 +0200 Subject: [PATCH] Fix `NDRectangle::set_range` overloads (#2157) --- tiledb/libtiledb/current_domain.cc | 265 ++++++++++++---------------- tiledb/tests/test_current_domain.py | 33 +++- 2 files changed, 142 insertions(+), 156 deletions(-) diff --git a/tiledb/libtiledb/current_domain.cc b/tiledb/libtiledb/current_domain.cc index f01896bc0a..2bcef500e8 100644 --- a/tiledb/libtiledb/current_domain.cc +++ b/tiledb/libtiledb/current_domain.cc @@ -22,168 +22,125 @@ void init_current_domain(py::module& m) { .def(py::init()) - .def( - "_set_range", - py::overload_cast( - &NDRectangle::set_range), - py::arg("dim_name"), - py::arg("start"), - py::arg("end")) - .def( - "_set_range", - py::overload_cast( - &NDRectangle::set_range), - py::arg("dim_name"), - py::arg("start"), - py::arg("end")) - .def( - "_set_range", - py::overload_cast( - &NDRectangle::set_range), - py::arg("dim_name"), - py::arg("start"), - py::arg("end")) - .def( - "_set_range", - py::overload_cast( - &NDRectangle::set_range), - py::arg("dim_name"), - py::arg("start"), - py::arg("end")) - .def( - "_set_range", - py::overload_cast( - &NDRectangle::set_range), - py::arg("dim_name"), - py::arg("start"), - py::arg("end")) - .def( - "_set_range", - py::overload_cast( - &NDRectangle::set_range), - py::arg("dim_name"), - py::arg("start"), - py::arg("end")) - .def( - "_set_range", - py::overload_cast( - &NDRectangle::set_range), - py::arg("dim_name"), - py::arg("start"), - py::arg("end")) - .def( - "_set_range", - py::overload_cast( - &NDRectangle::set_range), - py::arg("dim_name"), - py::arg("start"), - py::arg("end")) - .def( - "_set_range", - py::overload_cast( - &NDRectangle::set_range), - py::arg("dim_name"), - py::arg("start"), - py::arg("end")) - .def( - "_set_range", - py::overload_cast( - &NDRectangle::set_range), - py::arg("dim_name"), - py::arg("start"), - py::arg("end")) .def( "_set_range", [](NDRectangle& ndrect, const std::string& dim_name, - const std::string& start, - const std::string& end) { - return ndrect.set_range(dim_name, start, end); - }, - py::arg("dim_name"), - py::arg("start"), - py::arg("end")) - .def( - "_set_range", - py::overload_cast( - &NDRectangle::set_range), - py::arg("dim_idx"), - py::arg("start"), - py::arg("end")) - .def( - "_set_range", - py::overload_cast( - &NDRectangle::set_range), - py::arg("dim_idx"), - py::arg("start"), - py::arg("end")) - .def( - "_set_range", - py::overload_cast( - &NDRectangle::set_range), - py::arg("dim_idx"), - py::arg("start"), - py::arg("end")) - .def( - "_set_range", - py::overload_cast( - &NDRectangle::set_range), - py::arg("dim_idx"), - py::arg("start"), - py::arg("end")) - .def( - "_set_range", - py::overload_cast( - &NDRectangle::set_range), - py::arg("dim_idx"), - py::arg("start"), - py::arg("end")) - .def( - "_set_range", - py::overload_cast( - &NDRectangle::set_range), - py::arg("dim_idx"), - py::arg("start"), - py::arg("end")) - .def( - "_set_range", - py::overload_cast( - &NDRectangle::set_range), - py::arg("dim_idx"), - py::arg("start"), - py::arg("end")) - .def( - "_set_range", - py::overload_cast( - &NDRectangle::set_range), - py::arg("dim_idx"), - py::arg("start"), - py::arg("end")) - .def( - "_set_range", - py::overload_cast( - &NDRectangle::set_range), - py::arg("dim_idx"), - py::arg("start"), - py::arg("end")) - .def( - "_set_range", - py::overload_cast( - &NDRectangle::set_range), - py::arg("dim_idx"), - py::arg("start"), - py::arg("end")) + py::object start, + py::object end) { + const tiledb_datatype_t n_type = ndrect.range_dtype(dim_name); + + if (n_type == TILEDB_UINT64) { + auto start_val = start.cast(); + auto end_val = end.cast(); + ndrect.set_range(dim_name, start_val, end_val); + } else if (n_type == TILEDB_INT64) { + auto start_val = start.cast(); + auto end_val = end.cast(); + ndrect.set_range(dim_name, start_val, end_val); + } else if (n_type == TILEDB_UINT32) { + auto start_val = start.cast(); + auto end_val = end.cast(); + ndrect.set_range(dim_name, start_val, end_val); + } else if (n_type == TILEDB_INT32) { + auto start_val = start.cast(); + auto end_val = end.cast(); + ndrect.set_range(dim_name, start_val, end_val); + } else if (n_type == TILEDB_UINT16) { + auto start_val = start.cast(); + auto end_val = end.cast(); + ndrect.set_range(dim_name, start_val, end_val); + } else if (n_type == TILEDB_INT16) { + auto start_val = start.cast(); + auto end_val = end.cast(); + ndrect.set_range(dim_name, start_val, end_val); + } else if (n_type == TILEDB_UINT8) { + auto start_val = start.cast(); + auto end_val = end.cast(); + ndrect.set_range(dim_name, start_val, end_val); + } else if (n_type == TILEDB_INT8) { + auto start_val = start.cast(); + auto end_val = end.cast(); + ndrect.set_range(dim_name, start_val, end_val); + } else if (n_type == TILEDB_FLOAT64) { + auto start_val = start.cast(); + auto end_val = end.cast(); + ndrect.set_range(dim_name, start_val, end_val); + } else if (n_type == TILEDB_FLOAT32) { + auto start_val = start.cast(); + auto end_val = end.cast(); + ndrect.set_range(dim_name, start_val, end_val); + } else if ( + n_type == TILEDB_STRING_ASCII || + n_type == TILEDB_STRING_UTF8) { + auto start_val = start.cast(); + auto end_val = end.cast(); + ndrect.set_range(dim_name, start_val, end_val); + } else { + TPY_ERROR_LOC( + "Unsupported type for NDRectangle's set_range"); + } + }) + .def( "_set_range", [](NDRectangle& ndrect, uint32_t dim_idx, - const std::string& start, - const std::string& end) { - return ndrect.set_range(dim_idx, start, end); - }, - py::arg("dim_name"), - py::arg("start"), - py::arg("end")) + py::object start, + py::object end) { + const tiledb_datatype_t n_type = ndrect.range_dtype(dim_idx); + + if (n_type == TILEDB_UINT64) { + auto start_val = start.cast(); + auto end_val = end.cast(); + ndrect.set_range(dim_idx, start_val, end_val); + } else if (n_type == TILEDB_INT64) { + auto start_val = start.cast(); + auto end_val = end.cast(); + ndrect.set_range(dim_idx, start_val, end_val); + } else if (n_type == TILEDB_UINT32) { + auto start_val = start.cast(); + auto end_val = end.cast(); + ndrect.set_range(dim_idx, start_val, end_val); + } else if (n_type == TILEDB_INT32) { + auto start_val = start.cast(); + auto end_val = end.cast(); + ndrect.set_range(dim_idx, start_val, end_val); + } else if (n_type == TILEDB_UINT16) { + auto start_val = start.cast(); + auto end_val = end.cast(); + ndrect.set_range(dim_idx, start_val, end_val); + } else if (n_type == TILEDB_INT16) { + auto start_val = start.cast(); + auto end_val = end.cast(); + ndrect.set_range(dim_idx, start_val, end_val); + } else if (n_type == TILEDB_UINT8) { + auto start_val = start.cast(); + auto end_val = end.cast(); + ndrect.set_range(dim_idx, start_val, end_val); + } else if (n_type == TILEDB_INT8) { + auto start_val = start.cast(); + auto end_val = end.cast(); + ndrect.set_range(dim_idx, start_val, end_val); + } else if (n_type == TILEDB_FLOAT64) { + auto start_val = start.cast(); + auto end_val = end.cast(); + ndrect.set_range(dim_idx, start_val, end_val); + } else if (n_type == TILEDB_FLOAT32) { + auto start_val = start.cast(); + auto end_val = end.cast(); + ndrect.set_range(dim_idx, start_val, end_val); + } else if ( + n_type == TILEDB_STRING_ASCII || + n_type == TILEDB_STRING_UTF8) { + auto start_val = start.cast(); + auto end_val = end.cast(); + ndrect.set_range(dim_idx, start_val, end_val); + } else { + TPY_ERROR_LOC( + "Unsupported type for NDRectangle's set_range"); + } + }) .def( "_range", diff --git a/tiledb/tests/test_current_domain.py b/tiledb/tests/test_current_domain.py index 55dff8a56b..780a64f0a1 100644 --- a/tiledb/tests/test_current_domain.py +++ b/tiledb/tests/test_current_domain.py @@ -7,6 +7,8 @@ import tiledb import tiledb.libtiledb as lt +from .common import DiskTestCase + if not (lt.version()[0] == 2 and lt.version()[1] >= 25): pytest.skip( "CurrentDomain is only available in TileDB 2.26 and later", @@ -14,7 +16,7 @@ ) -class NDRectangleTest(unittest.TestCase): +class NDRectangleTest(DiskTestCase): def test_ndrectagle_standalone_string(self): ctx = tiledb.Ctx() dom = tiledb.Domain( @@ -57,8 +59,35 @@ def test_ndrectagle_standalone_integer(self): self.assertEqual(ndrect.range("x"), range_one) self.assertEqual(ndrect.range("y"), range_two) + @pytest.mark.parametrize( + "dtype_, range_", + ( + (np.int32, (0, 1)), + (np.int64, (2, 7)), + (np.uint32, (1, 5)), + (np.uint64, (5, 9)), + (np.float32, (0.0, 1.0)), + (np.float64, (2.0, 4.0)), + (np.dtype(bytes), (b"abc", b"def")), + ), + ) + def test_set_range_different_types(self, dtype_, range_): + domain = tiledb.Domain( + *[ + tiledb.Dim( + name="rows", + domain=(0, 9), + dtype=dtype_, + ), + ] + ) + + ctx = tiledb.Ctx() + ndrect = tiledb.NDRectangle(ctx, domain) + ndrect.set_range("rows", range_[0], range_[1]) + -class CurrentDomainTest(unittest.TestCase): +class CurrentDomainTest(DiskTestCase): def test_current_domain_with_ndrectangle_integer(self): ctx = tiledb.Ctx() dom = tiledb.Domain(