Skip to content

Commit

Permalink
Fix NDRectangle::set_range overloads (#2157)
Browse files Browse the repository at this point in the history
  • Loading branch information
kounelisagis authored Feb 11, 2025
1 parent 18c4296 commit 7144134
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 156 deletions.
265 changes: 111 additions & 154 deletions tiledb/libtiledb/current_domain.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,168 +22,125 @@ void init_current_domain(py::module& m) {

.def(py::init<const Context&, const Domain&>())

.def(
"_set_range",
py::overload_cast<const std::string&, uint64_t, uint64_t>(
&NDRectangle::set_range<uint64_t>),
py::arg("dim_name"),
py::arg("start"),
py::arg("end"))
.def(
"_set_range",
py::overload_cast<const std::string&, int64_t, int64_t>(
&NDRectangle::set_range<int64_t>),
py::arg("dim_name"),
py::arg("start"),
py::arg("end"))
.def(
"_set_range",
py::overload_cast<const std::string&, uint32_t, uint32_t>(
&NDRectangle::set_range<uint32_t>),
py::arg("dim_name"),
py::arg("start"),
py::arg("end"))
.def(
"_set_range",
py::overload_cast<const std::string&, int32_t, int32_t>(
&NDRectangle::set_range<int32_t>),
py::arg("dim_name"),
py::arg("start"),
py::arg("end"))
.def(
"_set_range",
py::overload_cast<const std::string&, uint16_t, uint16_t>(
&NDRectangle::set_range<uint16_t>),
py::arg("dim_name"),
py::arg("start"),
py::arg("end"))
.def(
"_set_range",
py::overload_cast<const std::string&, int16_t, int16_t>(
&NDRectangle::set_range<int16_t>),
py::arg("dim_name"),
py::arg("start"),
py::arg("end"))
.def(
"_set_range",
py::overload_cast<const std::string&, uint8_t, uint8_t>(
&NDRectangle::set_range<uint8_t>),
py::arg("dim_name"),
py::arg("start"),
py::arg("end"))
.def(
"_set_range",
py::overload_cast<const std::string&, int8_t, int8_t>(
&NDRectangle::set_range<int8_t>),
py::arg("dim_name"),
py::arg("start"),
py::arg("end"))
.def(
"_set_range",
py::overload_cast<const std::string&, double, double>(
&NDRectangle::set_range<double>),
py::arg("dim_name"),
py::arg("start"),
py::arg("end"))
.def(
"_set_range",
py::overload_cast<const std::string&, float, float>(
&NDRectangle::set_range<float>),
py::arg("dim_name"),
py::arg("start"),
py::arg("end"))
.def(
"_set_range",
[](NDRectangle& ndrect,
const std::string& dim_name,
const std::string& start,
const std::string& end) {
return ndrect.set_range(dim_name, start, end);
},
py::arg("dim_name"),
py::arg("start"),
py::arg("end"))
.def(
"_set_range",
py::overload_cast<uint32_t, uint64_t, uint64_t>(
&NDRectangle::set_range<uint64_t>),
py::arg("dim_idx"),
py::arg("start"),
py::arg("end"))
.def(
"_set_range",
py::overload_cast<uint32_t, int64_t, int64_t>(
&NDRectangle::set_range<int64_t>),
py::arg("dim_idx"),
py::arg("start"),
py::arg("end"))
.def(
"_set_range",
py::overload_cast<uint32_t, uint32_t, uint32_t>(
&NDRectangle::set_range<uint32_t>),
py::arg("dim_idx"),
py::arg("start"),
py::arg("end"))
.def(
"_set_range",
py::overload_cast<uint32_t, int32_t, int32_t>(
&NDRectangle::set_range<int32_t>),
py::arg("dim_idx"),
py::arg("start"),
py::arg("end"))
.def(
"_set_range",
py::overload_cast<uint32_t, uint16_t, uint16_t>(
&NDRectangle::set_range<uint16_t>),
py::arg("dim_idx"),
py::arg("start"),
py::arg("end"))
.def(
"_set_range",
py::overload_cast<uint32_t, int16_t, int16_t>(
&NDRectangle::set_range<int16_t>),
py::arg("dim_idx"),
py::arg("start"),
py::arg("end"))
.def(
"_set_range",
py::overload_cast<uint32_t, uint8_t, uint8_t>(
&NDRectangle::set_range<uint8_t>),
py::arg("dim_idx"),
py::arg("start"),
py::arg("end"))
.def(
"_set_range",
py::overload_cast<uint32_t, int8_t, int8_t>(
&NDRectangle::set_range<int8_t>),
py::arg("dim_idx"),
py::arg("start"),
py::arg("end"))
.def(
"_set_range",
py::overload_cast<uint32_t, double, double>(
&NDRectangle::set_range<double>),
py::arg("dim_idx"),
py::arg("start"),
py::arg("end"))
.def(
"_set_range",
py::overload_cast<uint32_t, float, float>(
&NDRectangle::set_range<float>),
py::arg("dim_idx"),
py::arg("start"),
py::arg("end"))
py::object start,
py::object end) {
const tiledb_datatype_t n_type = ndrect.range_dtype(dim_name);

if (n_type == TILEDB_UINT64) {
auto start_val = start.cast<uint64_t>();
auto end_val = end.cast<uint64_t>();
ndrect.set_range(dim_name, start_val, end_val);
} else if (n_type == TILEDB_INT64) {
auto start_val = start.cast<int64_t>();
auto end_val = end.cast<int64_t>();
ndrect.set_range(dim_name, start_val, end_val);
} else if (n_type == TILEDB_UINT32) {
auto start_val = start.cast<uint32_t>();
auto end_val = end.cast<uint32_t>();
ndrect.set_range(dim_name, start_val, end_val);
} else if (n_type == TILEDB_INT32) {
auto start_val = start.cast<int32_t>();
auto end_val = end.cast<int32_t>();
ndrect.set_range(dim_name, start_val, end_val);
} else if (n_type == TILEDB_UINT16) {
auto start_val = start.cast<uint16_t>();
auto end_val = end.cast<uint16_t>();
ndrect.set_range(dim_name, start_val, end_val);
} else if (n_type == TILEDB_INT16) {
auto start_val = start.cast<int16_t>();
auto end_val = end.cast<int16_t>();
ndrect.set_range(dim_name, start_val, end_val);
} else if (n_type == TILEDB_UINT8) {
auto start_val = start.cast<uint8_t>();
auto end_val = end.cast<uint8_t>();
ndrect.set_range(dim_name, start_val, end_val);
} else if (n_type == TILEDB_INT8) {
auto start_val = start.cast<int8_t>();
auto end_val = end.cast<int8_t>();
ndrect.set_range(dim_name, start_val, end_val);
} else if (n_type == TILEDB_FLOAT64) {
auto start_val = start.cast<double>();
auto end_val = end.cast<double>();
ndrect.set_range(dim_name, start_val, end_val);
} else if (n_type == TILEDB_FLOAT32) {
auto start_val = start.cast<float>();
auto end_val = end.cast<float>();
ndrect.set_range(dim_name, start_val, end_val);
} else if (
n_type == TILEDB_STRING_ASCII ||
n_type == TILEDB_STRING_UTF8) {
auto start_val = start.cast<std::string>();
auto end_val = end.cast<std::string>();
ndrect.set_range(dim_name, start_val, end_val);
} else {
TPY_ERROR_LOC(
"Unsupported type for NDRectangle's set_range");
}
})

.def(
"_set_range",
[](NDRectangle& ndrect,
uint32_t dim_idx,
const std::string& start,
const std::string& end) {
return ndrect.set_range(dim_idx, start, end);
},
py::arg("dim_name"),
py::arg("start"),
py::arg("end"))
py::object start,
py::object end) {
const tiledb_datatype_t n_type = ndrect.range_dtype(dim_idx);

if (n_type == TILEDB_UINT64) {
auto start_val = start.cast<uint64_t>();
auto end_val = end.cast<uint64_t>();
ndrect.set_range(dim_idx, start_val, end_val);
} else if (n_type == TILEDB_INT64) {
auto start_val = start.cast<int64_t>();
auto end_val = end.cast<int64_t>();
ndrect.set_range(dim_idx, start_val, end_val);
} else if (n_type == TILEDB_UINT32) {
auto start_val = start.cast<uint32_t>();
auto end_val = end.cast<uint32_t>();
ndrect.set_range(dim_idx, start_val, end_val);
} else if (n_type == TILEDB_INT32) {
auto start_val = start.cast<int32_t>();
auto end_val = end.cast<int32_t>();
ndrect.set_range(dim_idx, start_val, end_val);
} else if (n_type == TILEDB_UINT16) {
auto start_val = start.cast<uint16_t>();
auto end_val = end.cast<uint16_t>();
ndrect.set_range(dim_idx, start_val, end_val);
} else if (n_type == TILEDB_INT16) {
auto start_val = start.cast<int16_t>();
auto end_val = end.cast<int16_t>();
ndrect.set_range(dim_idx, start_val, end_val);
} else if (n_type == TILEDB_UINT8) {
auto start_val = start.cast<uint8_t>();
auto end_val = end.cast<uint8_t>();
ndrect.set_range(dim_idx, start_val, end_val);
} else if (n_type == TILEDB_INT8) {
auto start_val = start.cast<int8_t>();
auto end_val = end.cast<int8_t>();
ndrect.set_range(dim_idx, start_val, end_val);
} else if (n_type == TILEDB_FLOAT64) {
auto start_val = start.cast<double>();
auto end_val = end.cast<double>();
ndrect.set_range(dim_idx, start_val, end_val);
} else if (n_type == TILEDB_FLOAT32) {
auto start_val = start.cast<float>();
auto end_val = end.cast<float>();
ndrect.set_range(dim_idx, start_val, end_val);
} else if (
n_type == TILEDB_STRING_ASCII ||
n_type == TILEDB_STRING_UTF8) {
auto start_val = start.cast<std::string>();
auto end_val = end.cast<std::string>();
ndrect.set_range(dim_idx, start_val, end_val);
} else {
TPY_ERROR_LOC(
"Unsupported type for NDRectangle's set_range");
}
})

.def(
"_range",
Expand Down
33 changes: 31 additions & 2 deletions tiledb/tests/test_current_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
import tiledb
import tiledb.libtiledb as lt

from .common import DiskTestCase

if not (lt.version()[0] == 2 and lt.version()[1] >= 25):
pytest.skip(
"CurrentDomain is only available in TileDB 2.26 and later",
allow_module_level=True,
)


class NDRectangleTest(unittest.TestCase):
class NDRectangleTest(DiskTestCase):
def test_ndrectagle_standalone_string(self):
ctx = tiledb.Ctx()
dom = tiledb.Domain(
Expand Down Expand Up @@ -57,8 +59,35 @@ def test_ndrectagle_standalone_integer(self):
self.assertEqual(ndrect.range("x"), range_one)
self.assertEqual(ndrect.range("y"), range_two)

@pytest.mark.parametrize(
"dtype_, range_",
(
(np.int32, (0, 1)),
(np.int64, (2, 7)),
(np.uint32, (1, 5)),
(np.uint64, (5, 9)),
(np.float32, (0.0, 1.0)),
(np.float64, (2.0, 4.0)),
(np.dtype(bytes), (b"abc", b"def")),
),
)
def test_set_range_different_types(self, dtype_, range_):
domain = tiledb.Domain(
*[
tiledb.Dim(
name="rows",
domain=(0, 9),
dtype=dtype_,
),
]
)

ctx = tiledb.Ctx()
ndrect = tiledb.NDRectangle(ctx, domain)
ndrect.set_range("rows", range_[0], range_[1])


class CurrentDomainTest(unittest.TestCase):
class CurrentDomainTest(DiskTestCase):
def test_current_domain_with_ndrectangle_integer(self):
ctx = tiledb.Ctx()
dom = tiledb.Domain(
Expand Down

0 comments on commit 7144134

Please sign in to comment.