Skip to content

Commit

Permalink
[WIP] Support with query conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
nguyenv committed Feb 2, 2024
1 parent 891819a commit cbe8fbf
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 101 deletions.
150 changes: 82 additions & 68 deletions tiledb/core.cc
Original file line number Diff line number Diff line change
Expand Up @@ -294,67 +294,68 @@ uint64_t count_zeros(py::array_t<uint8_t> a) {

class PyAgg {

using ByteBuffer = py::array_t<uint8_t>;
using AggToBufferMap = std::map<std::string, ByteBuffer>;
using AttrToBuffersMap = std::map<std::string, AggToBufferMap>;

private:
Context ctx_;
shared_ptr<tiledb::ArraySchema> array_schema_;
shared_ptr<tiledb::Array> array_;
shared_ptr<tiledb::Query> query_;
map<string, map<string, py::array_t<uint8_t>>> result_buffers_;
map<string, map<string, py::array_t<uint8_t>>> validity_buffers_;
std::shared_ptr<tiledb::Array> array_;
std::shared_ptr<tiledb::Query> query_;
std::vector<std::string> attrs_;
AttrToBuffersMap result_buffers_;
AttrToBuffersMap validity_buffers_;

public:
PyAgg() = delete;

PyAgg(const Context &ctx, py::object array, py::object py_layout,
py::dict attr_to_aggs_map)
PyAgg(const Context &ctx, py::object py_array, py::object py_layout,
py::dict attr_to_aggs_input)
: ctx_(ctx){
tiledb_array_t *c_array_ = (py::capsule)array.attr("__capsule__")();

// we never own this pointer, pass own=false
array_ = std::shared_ptr<tiledb::Array>(new Array(ctx_, c_array_, false));

array_schema_ =
std::shared_ptr<tiledb::ArraySchema>(new ArraySchema(array_->schema()));
tiledb_array_t *c_array_ = (py::capsule)py_array.attr("__capsule__")();

// We cannot apply aggregates to a channel with an instantiated Query
// so we "reset" the Query object here
query_ = shared_ptr<tiledb::Query>(new Query(ctx_, *array_, TILEDB_READ));
// We never own this pointer; pass own=false
array_ = std::make_shared<tiledb::Array>(ctx_, c_array_, false);
query_ = std::make_shared<tiledb::Query>(ctx_, *array_, TILEDB_READ);

bool issparse = array_schema_->array_type() == TILEDB_SPARSE;
bool issparse = array_->schema().array_type() == TILEDB_SPARSE;
tiledb_layout_t layout = (tiledb_layout_t)py_layout.cast<int32_t>();
if (!issparse && layout == TILEDB_UNORDERED) {
TPY_ERROR_LOC("TILEDB_UNORDERED read is not supported for dense arrays")
}
query_->set_layout(layout);

// Set the data buffers for each attribute & aggregation function passed in
// by the user
for (auto attr_to_aggs : attr_to_aggs_map) {
std::string attr_name(attr_to_aggs.first.cast<string>());
auto attr_type = array_schema_->attribute(attr_name).type();
auto attr_ncell = array_schema_->attribute(attr_name).cell_size();
auto attr_nullable = array_schema_->attribute(attr_name).nullable();

for (auto agg : attr_to_aggs.second) {
// TODO separate into its own function
std::string agg_name(agg.cast<string>());
_apply_agg_to_attr(agg_name, attr_name);
// Iterate through the requested attributes
for (auto attr_to_aggs : attr_to_aggs_input) {
auto attr_name = attr_to_aggs.first.cast<string>();
auto aggregates = attr_to_aggs.second;

tiledb::Attribute attr = array_->schema().attribute(attr_name);
attrs_.push_back(attr_name);

// Iterate through the aggregate operations to apply on the given attribute
for (auto agg : aggregates) {
auto agg_name = agg.cast<string>();

_apply_agg_operator_to_attr(agg_name, attr_name);

// Set the buffer for the aggregation query
auto* res_buf = &result_buffers_[attr_name][agg_name];
if("count" == agg_name or "null_count" == agg_name or "mean" == agg_name){
// count and null_count use uint64 and mean uses float64
*res_buf = py::array(py::dtype("uint8"), 8);
}else{
// max, min, and sum use the dtype of the attribute
py::dtype dt(tiledb_dtype(attr_type, attr_ncell));
py::dtype dt(tiledb_dtype(attr.type(), attr.cell_size()));
*res_buf = py::array(py::dtype("uint8"), dt.itemsize());
}
query_->set_data_buffer(agg_name, (void*)res_buf->data(), 1);

// If the requested data contains all NULL values, we will not get an
// aggregate value back. We need to check the validity buffer beforehand
// For nullable attributes, if the input set for the aggregation contains
// all NULL values, we will not get an aggregate value back as this
// operation is undefined. We need to check the validity buffer beforehand
// to see if we had a valid result
if(attr_nullable and !("count" == agg_name or "null_count" == agg_name)){
if(attr.nullable() and !("count" == agg_name or "null_count" == agg_name)){
auto* val_buf = &validity_buffers_[attr_name][agg_name];
*val_buf = py::array(py::dtype("uint8"), 1);
query_->set_validity_buffer(agg_name, (uint8_t*)val_buf->data(), 1);
Expand All @@ -363,7 +364,7 @@ class PyAgg {
}
}

void _apply_agg_to_attr(
void _apply_agg_operator_to_attr(
const std::string &op_label, const std::string &attr_name) {
using AggregateFunc = std::function<ChannelOperation(
const Query &, const std::string &)>;
Expand All @@ -384,11 +385,11 @@ class PyAgg {
AggregateFunc create_unary_aggregate = label_to_agg_func.at(op_label);
ChannelOperation op = create_unary_aggregate(*query_, attr_name);
default_channel.apply_aggregate(op_label, op);
} else if (op_label == "count") {
} else if ("count" == op_label) {
default_channel.apply_aggregate(op_label, CountOperation());
} else {
TPY_ERROR_LOC("Invalid channel operation " + op_label +
" passed to apply_aggregate.");
" passed to apply_aggregate.");
}
}

Expand All @@ -401,8 +402,8 @@ class PyAgg {
py::str attr_name(attr.first);
result[attr_name] = py::dict();

auto attr_type = array_schema_->attribute(attr_name).type();
auto attr_nullable = array_schema_->attribute(attr_name).nullable();
auto attr_type = array_->schema().attribute(attr_name).type();
auto attr_nullable = array_->schema().attribute(attr_name).nullable();

for (auto agg : attr.second) {
py::str agg_name(agg.first);
Expand Down Expand Up @@ -470,20 +471,35 @@ class PyAgg {
// Restrict the aggregate query to the given subarray.
// `py_subarray` is a Python object that must be castable to a
// `tiledb::Subarray *`; the pointed-to subarray is passed by reference to
// the underlying query's set_subarray.
void set_subarray(py::object py_subarray) {
query_->set_subarray(*py_subarray.cast<tiledb::Subarray *>());
}

// Attach a query condition to the aggregate query.
// `cond` is a Python object that exposes an `init_query_condition` callable
// and a `c_obj` attribute castable to PyQueryCondition.
void set_cond(py::object cond) {
py::object init_pyqc = cond.attr("init_query_condition");

// Let the Python side initialize the condition using the array URI, the
// attribute names stored in attrs_, and the context. Both TileDB errors and
// Python exceptions raised by the call are routed through TPY_ERROR_LOC.
// NOTE(review): TPY_ERROR_LOC presumably throws, so execution would not
// continue past a failed initialization — confirm against the macro.
try {
init_pyqc(array_->uri(), attrs_, ctx_);
} catch (tiledb::TileDBError &e) {
TPY_ERROR_LOC(e.what());
} catch (py::error_already_set &e) {
TPY_ERROR_LOC(e.what());
}
// Extract the C++ wrapper from the Python object and apply the condition
// it holds to the query.
// NOTE(review): ptr() presumably returns a smart pointer to the underlying
// tiledb::QueryCondition; `pyqc` is a local copy that keeps it alive for
// the duration of set_condition — confirm.
auto pyqc = (cond.attr("c_obj")).cast<PyQueryCondition>();
auto qc = pyqc.ptr().get();
query_->set_condition(*qc);
}
};

class PyQuery {

private:
Context ctx_;
shared_ptr<tiledb::Domain> domain_;
shared_ptr<tiledb::ArraySchema> array_schema_;
shared_ptr<tiledb::Array> array_;
shared_ptr<tiledb::Query> query_;
vector<string> attrs_;
vector<string> dims_;
map<string, BufferInfo> buffers_;
vector<string> buffers_order_;
std::shared_ptr<tiledb::Domain> domain_;
std::shared_ptr<tiledb::ArraySchema> array_schema_;
std::shared_ptr<tiledb::Array> array_;
std::shared_ptr<tiledb::Query> query_;
std::vector<std::string> attrs_;
std::vector<std::string> dims_;
std::map<std::string, BufferInfo> buffers_;
std::vector<std::string> buffers_order_;

bool deduplicate_ = true;
bool use_arrow_ = false;
Expand All @@ -497,8 +513,6 @@ class PyQuery {
// label buffer list
unordered_map<string, uint64_t> label_input_buffer_data_;

string uri_;

public:
tiledb_ctx_t *c_ctx_;
tiledb_array_t *c_array_;
Expand All @@ -522,15 +536,11 @@ class PyQuery {
tiledb_array_t *c_array_ = (py::capsule)array.attr("__capsule__")();

// we never own this pointer, pass own=false
array_ = std::shared_ptr<tiledb::Array>(new Array(ctx_, c_array_, false));

array_schema_ =
std::shared_ptr<tiledb::ArraySchema>(new ArraySchema(array_->schema()));
array_ = std::make_shared<tiledb::Array>(ctx_, c_array_, false);

domain_ =
std::shared_ptr<tiledb::Domain>(new Domain(array_schema_->domain()));
array_schema_ = std::make_shared<tiledb::ArraySchema>(array_->schema());

uri_ = array.attr("uri").cast<std::string>();
domain_ = std::make_shared<tiledb::Domain>(array_schema_->domain());

bool issparse = array_->schema().array_type() == TILEDB_SPARSE;

Expand Down Expand Up @@ -573,8 +583,7 @@ class PyQuery {
}
}

query_ =
std::shared_ptr<tiledb::Query>(new Query(ctx_, *array_, query_mode));
query_ = std::make_shared<tiledb::Query>(ctx_, *array_, query_mode);
// [](Query* p){} /* note: no deleter*/);

if (query_mode == TILEDB_READ) {
Expand Down Expand Up @@ -630,7 +639,7 @@ class PyQuery {
py::object init_pyqc = cond.attr("init_query_condition");

try {
init_pyqc(uri_, attrs_, ctx_);
init_pyqc(array_->uri(), attrs_, ctx_);
} catch (tiledb::TileDBError &e) {
TPY_ERROR_LOC(e.what());
} catch (py::error_already_set &e) {
Expand Down Expand Up @@ -1712,10 +1721,20 @@ void init_core(py::module &m) {
&PyQuery::_test_alloc_max_bytes)
.def_readonly("retries", &PyQuery::retries_);

py::class_<PyAgg>(m, "PyAgg")
.def(py::init<const Context &, py::object, py::object, py::dict>())
.def("set_subarray", &PyAgg::set_subarray)
.def("get_aggregate", &PyAgg::get_aggregate);
py::class_<PyAgg>(m, "PyAgg")
.def(py::init<const Context &, py::object, py::object, py::dict>(),
"ctx"_a,
"py_array"_a,
"py_layout"_a,
"attr_to_aggs_input"_a)
.def("set_subarray", &PyAgg::set_subarray)
.def("set_cond", &PyAgg::set_cond)
.def("get_aggregate", &PyAgg::get_aggregate);

py::class_<PAPair>(m, "PAPair")
.def(py::init())
.def("get_array", &PAPair::get_array)
.def("get_schema", &PAPair::get_schema);

m.def("array_to_buffer", &convert_np);

Expand All @@ -1727,11 +1746,6 @@ void init_core(py::module &m) {
m.def("get_stats", &get_stats);
m.def("use_stats", &use_stats);

py::class_<PAPair>(m, "PAPair")
.def(py::init())
.def("get_array", &PAPair::get_array)
.def("get_schema", &PAPair::get_schema);

/*
We need to make sure C++ TileDBError is translated to a correctly-typed py
error. Note that using py::exception(..., "TileDBError") creates a new
Expand Down
49 changes: 18 additions & 31 deletions tiledb/libtiledb.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1738,7 +1738,10 @@ cdef class Query(object):
if not array.schema.has_attr(name):
raise TileDBError(f"Selected attribute does not exist: '{name}'")
self.attrs = attrs
self.cond = cond

if cond is not None and not isinstance(cond, str):
raise TypeError("`cond` expects type str.")
self.cond = None if cond == "" else cond

if order == None:
if array.schema.sparse:
Expand Down Expand Up @@ -1784,8 +1787,16 @@ cdef class Query(object):
subarray.add_ranges([list([x]) for x in dim_ranges])
q.set_subarray(subarray)

if self.cond is not None:
from .query_condition import QueryCondition
q.set_cond(QueryCondition(self.cond))

result = q.get_aggregate()

# Clear the pyquery after calculating the aggregates, otherwise the
# PyAgg object sticks around when calling subsequent Array.query
self.array.pyquery = None

# If there was only one attribute, just show the aggregate results
if len(result) == 1:
result = result[list(result.keys())[0]]
Expand Down Expand Up @@ -1842,7 +1853,8 @@ cdef class Query(object):
"'G' (TILEDB_GLOBAL_ORDER), "\
"or 'U' (TILEDB_UNORDERED)")

self.array.pyquery = PyAgg(self.array._ctx_(), self.array, <int32_t>layout, attr_to_aggs_map)
self.array.pyquery = PyAgg(self.array._ctx_(), self.array,
<int32_t>layout, attr_to_aggs_map)

return self

Expand Down Expand Up @@ -2218,22 +2230,9 @@ cdef class DenseArrayImpl(Array):
q = PyQuery(self._ctx_(), self, tuple(attr_names), tuple(), <int32_t>layout, False)
self.pyquery = q


if cond is not None and cond != "":
if cond is not None:
from .query_condition import QueryCondition

if isinstance(cond, str):
q.set_cond(QueryCondition(cond))
elif isinstance(cond, QueryCondition):
raise TileDBError(
"Passing `tiledb.QueryCondition` to `cond` is no longer "
"supported as of 0.19.0. Instead of "
"`cond=tiledb.QueryCondition('expression')` "
"you must use `cond='expression'`. This message will be "
"removed in 0.21.0.",
)
else:
raise TypeError("`cond` expects type str.")
q.set_cond(QueryCondition(self.cond))

q.set_subarray(subarray)
q.submit()
Expand Down Expand Up @@ -3329,21 +3328,9 @@ cdef class SparseArrayImpl(Array):
q = PyQuery(self._ctx_(), self, tuple(attr_names), tuple(), <int32_t>layout, False)
self.pyquery = q

if cond is not None and cond != "":
if cond is not None:
from .query_condition import QueryCondition

if isinstance(cond, str):
q.set_cond(QueryCondition(cond))
elif isinstance(cond, QueryCondition):
raise TileDBError(
"Passing `tiledb.QueryCondition` to `cond` is no longer "
"supported as of 0.19.0. Instead of "
"`cond=tiledb.QueryCondition('expression')` "
"you must use `cond='expression'`. This message will be "
"removed in 0.21.0.",
)
else:
raise TypeError("`cond` expects type str.")
q.set_cond(QueryCondition(cond))

if self.mode == "r":
q.set_subarray(subarray)
Expand Down
Loading

0 comments on commit cbe8fbf

Please sign in to comment.