Skip to content

Commit

Permalink
Wrap TileDB Aggregate API (#1889)
Browse files Browse the repository at this point in the history
This commit wraps the TileDB Aggregate API (x-ref TileDB-Inc/TileDB#4438).
  • Loading branch information
nguyenv committed Mar 6, 2024
1 parent 7d77872 commit 4f5e260
Show file tree
Hide file tree
Showing 6 changed files with 955 additions and 35 deletions.
3 changes: 2 additions & 1 deletion tiledb/cc/common.cc
Expand Up @@ -168,7 +168,8 @@ tiledb_datatype_t np_to_tdb_dtype(py::dtype type) {
if (kind == py::str("U"))
return TILEDB_STRING_UTF8;

TPY_ERROR_LOC("could not handle numpy dtype");
TPY_ERROR_LOC("could not handle numpy dtype: " +
py::getattr(type, "name").cast<std::string>());
}

bool is_tdb_num(tiledb_datatype_t type) {
Expand Down
295 changes: 268 additions & 27 deletions tiledb/core.cc
@@ -1,6 +1,7 @@
#include <chrono>
#include <cmath>
#include <cstring>
#include <functional>
#include <future>
#include <iostream>
#include <map>
Expand Down Expand Up @@ -40,7 +41,6 @@

namespace tiledbpy {

using namespace std;
using namespace tiledb;
namespace py = pybind11;
using namespace pybind11::literals;
Expand Down Expand Up @@ -297,18 +297,260 @@ uint64_t count_zeros(py::array_t<uint8_t> a) {
return count;
}

class PyAgg {

using ByteBuffer = py::array_t<uint8_t>;
using AggToBufferMap = std::map<std::string, ByteBuffer>;
using AttrToAggsMap = std::map<std::string, AggToBufferMap>;

private:
Context ctx_;
std::shared_ptr<tiledb::Array> array_;
std::shared_ptr<tiledb::Query> query_;
AttrToAggsMap result_buffers_;
AttrToAggsMap validity_buffers_;

py::dict original_input_;
std::vector<std::string> attrs_;

public:
PyAgg() = delete;

PyAgg(const Context &ctx, py::object py_array, py::object py_layout,
py::dict attr_to_aggs_input)
: ctx_(ctx), original_input_(attr_to_aggs_input) {
tiledb_array_t *c_array_ = (py::capsule)py_array.attr("__capsule__")();

// We never own this pointer; pass own=false
array_ = std::make_shared<tiledb::Array>(ctx_, c_array_, false);
query_ = std::make_shared<tiledb::Query>(ctx_, *array_, TILEDB_READ);

bool issparse = array_->schema().array_type() == TILEDB_SPARSE;
tiledb_layout_t layout = (tiledb_layout_t)py_layout.cast<int32_t>();
if (!issparse && layout == TILEDB_UNORDERED) {
TPY_ERROR_LOC("TILEDB_UNORDERED read is not supported for dense arrays")
}
query_->set_layout(layout);

// Iterate through the requested attributes
for (auto attr_to_aggs : attr_to_aggs_input) {
auto attr_name = attr_to_aggs.first.cast<std::string>();
auto aggs = attr_to_aggs.second.cast<std::vector<std::string>>();

tiledb::Attribute attr = array_->schema().attribute(attr_name);
attrs_.push_back(attr_name);

// For non-nullable attributes, applying max and min to the empty set is
// undefined. To check for this, we need to also run the count aggregate
// to make sure count != 0
bool requested_max =
std::find(aggs.begin(), aggs.end(), "max") != aggs.end();
bool requested_min =
std::find(aggs.begin(), aggs.end(), "min") != aggs.end();
if (!attr.nullable() && (requested_max || requested_min)) {
// If the user already also requested count, then we don't need to
// request it again
if (std::find(aggs.begin(), aggs.end(), "count") == aggs.end()) {
aggs.push_back("count");
}
}

// Iterate through the aggreate operations to apply on the given attribute
for (auto agg_name : aggs) {
_apply_agg_operator_to_attr(agg_name, attr_name);

// Set the result data buffers
auto *res_buf = &result_buffers_[attr_name][agg_name];
if ("count" == agg_name || "null_count" == agg_name ||
"mean" == agg_name) {
// count and null_count use uint64 and mean uses float64
*res_buf = py::array(py::dtype("uint8"), 8);
} else {
// max, min, and sum use the dtype of the attribute
py::dtype dt(tiledb_dtype(attr.type(), attr.cell_size()));
*res_buf = py::array(py::dtype("uint8"), dt.itemsize());
}
query_->set_data_buffer(attr_name + agg_name, (void *)res_buf->data(),
1);

if (attr.nullable()) {
// For nullable attributes, if the input set for the aggregation
// contains all NULL values, we will not get an aggregate value back
// as this operation is undefined. We need to check the validity
// buffer beforehand to see if we had a valid result
if (!("count" == agg_name || "null_count" == agg_name)) {
auto *val_buf = &validity_buffers_[attr.name()][agg_name];
*val_buf = py::array(py::dtype("uint8"), 1);
query_->set_validity_buffer(attr_name + agg_name,
(uint8_t *)val_buf->data(), 1);
}
}
}
}
}

void _apply_agg_operator_to_attr(const std::string &op_label,
const std::string &attr_name) {
using AggregateFunc =
std::function<ChannelOperation(const Query &, const std::string &)>;

std::unordered_map<std::string, AggregateFunc> label_to_agg_func = {
{"sum", QueryExperimental::create_unary_aggregate<SumOperator>},
{"min", QueryExperimental::create_unary_aggregate<MinOperator>},
{"max", QueryExperimental::create_unary_aggregate<MaxOperator>},
{"mean", QueryExperimental::create_unary_aggregate<MeanOperator>},
{"null_count",
QueryExperimental::create_unary_aggregate<NullCountOperator>},
};

QueryChannel default_channel =
QueryExperimental::get_default_channel(*query_);

if (label_to_agg_func.find(op_label) != label_to_agg_func.end()) {
AggregateFunc create_unary_aggregate = label_to_agg_func.at(op_label);
ChannelOperation op = create_unary_aggregate(*query_, attr_name);
default_channel.apply_aggregate(attr_name + op_label, op);
} else if ("count" == op_label) {
default_channel.apply_aggregate(attr_name + op_label, CountOperation());
} else {
TPY_ERROR_LOC("Invalid channel operation " + op_label +
" passed to apply_aggregate.");
}
}

py::dict get_aggregate() {
query_->submit();

// Cast the results to the correct dtype and output this as a Python dict
py::dict output;
for (auto attr_to_agg : original_input_) {
// Be clear in our variable names for strings as py::dict uses py::str
// keys whereas std::map uses std::string keys
std::string attr_cpp_name = attr_to_agg.first.cast<string>();

py::str attr_py_name(attr_cpp_name);
output[attr_py_name] = py::dict();

tiledb::Attribute attr = array_->schema().attribute(attr_cpp_name);

for (auto agg_py_name : original_input_[attr_py_name]) {
std::string agg_cpp_name = agg_py_name.cast<string>();

if (_is_invalid(attr, agg_cpp_name)) {
output[attr_py_name][agg_py_name] =
_is_integer_dtype(attr) ? py::none() : py::cast(NAN);
} else {
output[attr_py_name][agg_py_name] = _set_result(attr, agg_cpp_name);
}
}
}
return output;
}

bool _is_invalid(tiledb::Attribute attr, std::string agg_name) {
if (attr.nullable()) {
if ("count" == agg_name || "null_count" == agg_name)
return false;

// For nullable attributes, check if the validity buffer returned false
const void *val_buf = validity_buffers_[attr.name()][agg_name].data();
return *((uint8_t *)(val_buf)) == 0;
} else {
// For non-nullable attributes, max and min are undefined for the empty
// set, so we must check the count == 0
if ("max" == agg_name || "min" == agg_name) {
const void *count_buf = result_buffers_[attr.name()]["count"].data();
return *((uint64_t *)(count_buf)) == 0;
}
return false;
}
}

bool _is_integer_dtype(tiledb::Attribute attr) {
switch (attr.type()) {
case TILEDB_INT8:
case TILEDB_INT16:
case TILEDB_UINT8:
case TILEDB_INT32:
case TILEDB_INT64:
case TILEDB_UINT16:
case TILEDB_UINT32:
case TILEDB_UINT64:
return true;
default:
return false;
}
}

py::object _set_result(tiledb::Attribute attr, std::string agg_name) {
const void *agg_buf = result_buffers_[attr.name()][agg_name].data();

if ("mean" == agg_name)
return py::cast(*((double *)agg_buf));

if ("count" == agg_name || "null_count" == agg_name)
return py::cast(*((uint64_t *)agg_buf));

switch (attr.type()) {
case TILEDB_FLOAT32:
return py::cast("sum" == agg_name ? *((double *)agg_buf)
: *((float *)agg_buf));
case TILEDB_FLOAT64:
return py::cast(*((double *)agg_buf));
case TILEDB_INT8:
return py::cast(*((int8_t *)agg_buf));
case TILEDB_UINT8:
return py::cast(*((uint8_t *)agg_buf));
case TILEDB_INT16:
return py::cast(*((int16_t *)agg_buf));
case TILEDB_UINT16:
return py::cast(*((uint16_t *)agg_buf));
case TILEDB_UINT32:
return py::cast(*((uint32_t *)agg_buf));
case TILEDB_INT32:
return py::cast(*((int32_t *)agg_buf));
case TILEDB_INT64:
return py::cast(*((int64_t *)agg_buf));
case TILEDB_UINT64:
return py::cast(*((uint64_t *)agg_buf));
default:
TPY_ERROR_LOC(
"[_cast_agg_result] Invalid tiledb dtype for aggregation result")
}
}

void set_subarray(py::object py_subarray) {
query_->set_subarray(*py_subarray.cast<tiledb::Subarray *>());
}

void set_cond(py::object cond) {
py::object init_pyqc = cond.attr("init_query_condition");

try {
init_pyqc(array_->uri(), attrs_, ctx_);
} catch (tiledb::TileDBError &e) {
TPY_ERROR_LOC(e.what());
} catch (py::error_already_set &e) {
TPY_ERROR_LOC(e.what());
}
auto pyqc = (cond.attr("c_obj")).cast<PyQueryCondition>();
auto qc = pyqc.ptr().get();
query_->set_condition(*qc);
}
};

class PyQuery {

private:
Context ctx_;
shared_ptr<tiledb::Domain> domain_;
shared_ptr<tiledb::ArraySchema> array_schema_;
shared_ptr<tiledb::Array> array_;
shared_ptr<tiledb::Query> query_;
std::shared_ptr<tiledb::Domain> domain_;
std::shared_ptr<tiledb::ArraySchema> array_schema_;
std::shared_ptr<tiledb::Array> array_;
std::shared_ptr<tiledb::Query> query_;
std::vector<std::string> attrs_;
std::vector<std::string> dims_;
map<string, BufferInfo> buffers_;
vector<string> buffers_order_;
std::map<std::string, BufferInfo> buffers_;
std::vector<std::string> buffers_order_;

bool deduplicate_ = true;
bool use_arrow_ = false;
Expand All @@ -320,9 +562,7 @@ class PyQuery {
tiledb_layout_t layout_ = TILEDB_ROW_MAJOR;

// label buffer list
std::unordered_map<string, uint64_t> label_input_buffer_data_;

std::string uri_;
unordered_map<string, uint64_t> label_input_buffer_data_;

public:
tiledb_ctx_t *c_ctx_;
Expand All @@ -347,15 +587,11 @@ class PyQuery {
tiledb_array_t *c_array_ = (py::capsule)array.attr("__capsule__")();

// we never own this pointer, pass own=false
array_ = std::shared_ptr<tiledb::Array>(new Array(ctx_, c_array_, false));

array_schema_ =
std::shared_ptr<tiledb::ArraySchema>(new ArraySchema(array_->schema()));
array_ = std::make_shared<tiledb::Array>(ctx_, c_array_, false);

domain_ =
std::shared_ptr<tiledb::Domain>(new Domain(array_schema_->domain()));
array_schema_ = std::make_shared<tiledb::ArraySchema>(array_->schema());

uri_ = array.attr("uri").cast<std::string>();
domain_ = std::make_shared<tiledb::Domain>(array_schema_->domain());

bool issparse = array_->schema().array_type() == TILEDB_SPARSE;

Expand Down Expand Up @@ -398,8 +634,7 @@ class PyQuery {
}
}

query_ =
std::shared_ptr<tiledb::Query>(new Query(ctx_, *array_, query_mode));
query_ = std::make_shared<tiledb::Query>(ctx_, *array_, query_mode);
// [](Query* p){} /* note: no deleter*/);

if (query_mode == TILEDB_READ) {
Expand All @@ -424,8 +659,7 @@ class PyQuery {
}

void set_subarray(py::object py_subarray) {
tiledb::Subarray *subarray = py_subarray.cast<tiledb::Subarray *>();
query_->set_subarray(*subarray);
query_->set_subarray(*py_subarray.cast<tiledb::Subarray *>());
}

#if defined(TILEDB_SERIALIZATION)
Expand Down Expand Up @@ -456,7 +690,7 @@ class PyQuery {
py::object init_pyqc = cond.attr("init_query_condition");

try {
init_pyqc(uri_, attrs_, ctx_);
init_pyqc(array_->uri(), attrs_, ctx_);
} catch (tiledb::TileDBError &e) {
TPY_ERROR_LOC(e.what());
} catch (py::error_already_set &e) {
Expand Down Expand Up @@ -1538,6 +1772,18 @@ void init_core(py::module &m) {
&PyQuery::_test_alloc_max_bytes)
.def_readonly("retries", &PyQuery::retries_);

py::class_<PyAgg>(m, "PyAgg")
.def(py::init<const Context &, py::object, py::object, py::dict>(),
"ctx"_a, "py_array"_a, "py_layout"_a, "attr_to_aggs_input"_a)
.def("set_subarray", &PyAgg::set_subarray)
.def("set_cond", &PyAgg::set_cond)
.def("get_aggregate", &PyAgg::get_aggregate);

py::class_<PAPair>(m, "PAPair")
.def(py::init())
.def("get_array", &PAPair::get_array)
.def("get_schema", &PAPair::get_schema);

m.def("array_to_buffer", &convert_np);

m.def("init_stats", &init_stats);
Expand All @@ -1548,11 +1794,6 @@ void init_core(py::module &m) {
m.def("get_stats", &get_stats);
m.def("use_stats", &use_stats);

py::class_<PAPair>(m, "PAPair")
.def(py::init())
.def("get_array", &PAPair::get_array)
.def("get_schema", &PAPair::get_schema);

/*
We need to make sure C++ TileDBError is translated to a correctly-typed py
error. Note that using py::exception(..., "TileDBError") creates a new
Expand Down
4 changes: 4 additions & 0 deletions tiledb/libtiledb.pxd
Expand Up @@ -1211,6 +1211,10 @@ cdef class SparseArrayImpl(Array):
cdef class DenseArrayImpl(Array):
cdef _read_dense_subarray(self, object subarray, list attr_names, object cond, tiledb_layout_t layout, bint include_coords)

# Declaration of the Aggregation extension type and its C-level fields.
# NOTE(review): the implementation is not visible in this declaration file;
# the field descriptions below are inferred from names and types only.
cdef class Aggregation(object):
# Query instance held by this Aggregation (presumably the query the
# aggregation was created from -- confirm against the implementation).
cdef Query query
# Arbitrary Python object; the name suggests a mapping from attribute names
# to the aggregates requested for each -- TODO confirm.
cdef object attr_to_aggs

cdef class Query(object):
cdef Array array
cdef object attrs
Expand Down

0 comments on commit 4f5e260

Please sign in to comment.