Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap Aggregate API #1889

Merged
merged 29 commits into from Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
4f6784f
[WIP] Wrap Aggregate API
nguyenv Jan 18, 2024
759adb2
Use std::function instead of raw function ptr
nguyenv Jan 18, 2024
d3e4346
[WIP] Support slicing
nguyenv Feb 1, 2024
b993bf4
[WIP] Support null_count for nullable attributes
nguyenv Feb 1, 2024
8e4c555
[WIP] Add unit test file
nguyenv Feb 1, 2024
c9bb97a
[WIP] Format
nguyenv Feb 1, 2024
1fa6712
[WIP] Remove unnecessary pybind11 wrapping
nguyenv Feb 1, 2024
ec500e7
[WIP] Full support for nullable attributes
nguyenv Feb 1, 2024
92aaf00
[WIP] Remove unnecessary member attributes
nguyenv Feb 1, 2024
dc9cb3a
[WIP] Support with query conditions
nguyenv Feb 2, 2024
49ef942
[WIP] Handle invalid non-nullable attribute results
nguyenv Feb 2, 2024
19ece59
[WIP] Only check for invalid on max and min for non-nullable attrs
nguyenv Feb 2, 2024
1700eac
Run Formatter
nguyenv Feb 2, 2024
17b6f66
Correct name to AttrToAggsMap
nguyenv Feb 2, 2024
3fbfb4b
Use std, fix qc test
nguyenv Feb 2, 2024
6851f92
Format
nguyenv Feb 2, 2024
5256f9c
Support with multi_index
nguyenv Feb 6, 2024
e05e5ee
Remove extraneous tests
nguyenv Feb 6, 2024
af508e8
QueryCondition with multi_index
nguyenv Feb 6, 2024
4100dda
Format
nguyenv Feb 6, 2024
5824d10
Clean _get_pyagg
nguyenv Feb 6, 2024
1d03a4f
Clean unit tests
nguyenv Feb 6, 2024
c06208f
Revert QueryCondition Changes -- Will Do In Separate PR
nguyenv Feb 13, 2024
e6e8c23
Format
nguyenv Feb 13, 2024
4fb5947
Use && and ||
nguyenv Feb 13, 2024
dd57947
Update to return None for integers and np.nan for floats on invalid agg
nguyenv Feb 15, 2024
5361950
WIP add doc and examp
nguyenv Feb 22, 2024
c328590
Merge remote-tracking branch 'origin/dev' into viviannguyen/sc-39372/…
ihnorton Mar 2, 2024
786cbcc
Use the fixed-up 'data' variable correctly - ensures at-least-one QC …
ihnorton Mar 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion tiledb/cc/common.cc
Expand Up @@ -168,7 +168,8 @@ tiledb_datatype_t np_to_tdb_dtype(py::dtype type) {
if (kind == py::str("U"))
return TILEDB_STRING_UTF8;

TPY_ERROR_LOC("could not handle numpy dtype");
TPY_ERROR_LOC("could not handle numpy dtype: " +
py::getattr(type, "name").cast<std::string>());
}

bool is_tdb_num(tiledb_datatype_t type) {
Expand Down
295 changes: 268 additions & 27 deletions tiledb/core.cc
@@ -1,6 +1,7 @@
#include <chrono>
#include <cmath>
#include <cstring>
#include <functional>
#include <future>
#include <iostream>
#include <map>
Expand Down Expand Up @@ -40,7 +41,6 @@

namespace tiledbpy {

using namespace std;
using namespace tiledb;
namespace py = pybind11;
using namespace pybind11::literals;
Expand Down Expand Up @@ -297,18 +297,260 @@ uint64_t count_zeros(py::array_t<uint8_t> a) {
return count;
}

class PyAgg {

using ByteBuffer = py::array_t<uint8_t>;
using AggToBufferMap = std::map<std::string, ByteBuffer>;
using AttrToAggsMap = std::map<std::string, AggToBufferMap>;

private:
Context ctx_;
std::shared_ptr<tiledb::Array> array_;
std::shared_ptr<tiledb::Query> query_;
AttrToAggsMap result_buffers_;
AttrToAggsMap validity_buffers_;

py::dict original_input_;
std::vector<std::string> attrs_;

public:
PyAgg() = delete;

PyAgg(const Context &ctx, py::object py_array, py::object py_layout,
py::dict attr_to_aggs_input)
: ctx_(ctx), original_input_(attr_to_aggs_input) {
tiledb_array_t *c_array_ = (py::capsule)py_array.attr("__capsule__")();

// We never own this pointer; pass own=false
array_ = std::make_shared<tiledb::Array>(ctx_, c_array_, false);
query_ = std::make_shared<tiledb::Query>(ctx_, *array_, TILEDB_READ);

bool issparse = array_->schema().array_type() == TILEDB_SPARSE;
tiledb_layout_t layout = (tiledb_layout_t)py_layout.cast<int32_t>();
if (!issparse && layout == TILEDB_UNORDERED) {
TPY_ERROR_LOC("TILEDB_UNORDERED read is not supported for dense arrays")
}
query_->set_layout(layout);

// Iterate through the requested attributes
for (auto attr_to_aggs : attr_to_aggs_input) {
auto attr_name = attr_to_aggs.first.cast<std::string>();
auto aggs = attr_to_aggs.second.cast<std::vector<std::string>>();

tiledb::Attribute attr = array_->schema().attribute(attr_name);
attrs_.push_back(attr_name);

// For non-nullable attributes, applying max and min to the empty set is
// undefined. To check for this, we need to also run the count aggregate
// to make sure count != 0
bool requested_max =
std::find(aggs.begin(), aggs.end(), "max") != aggs.end();
bool requested_min =
std::find(aggs.begin(), aggs.end(), "min") != aggs.end();
if (!attr.nullable() && (requested_max || requested_min)) {
// If the user already also requested count, then we don't need to
// request it again
if (std::find(aggs.begin(), aggs.end(), "count") == aggs.end()) {
aggs.push_back("count");
}
}

// Iterate through the aggreate operations to apply on the given attribute
for (auto agg_name : aggs) {
_apply_agg_operator_to_attr(agg_name, attr_name);

// Set the result data buffers
auto *res_buf = &result_buffers_[attr_name][agg_name];
if ("count" == agg_name || "null_count" == agg_name ||
"mean" == agg_name) {
// count and null_count use uint64 and mean uses float64
*res_buf = py::array(py::dtype("uint8"), 8);
} else {
// max, min, and sum use the dtype of the attribute
py::dtype dt(tiledb_dtype(attr.type(), attr.cell_size()));
*res_buf = py::array(py::dtype("uint8"), dt.itemsize());
}
query_->set_data_buffer(attr_name + agg_name, (void *)res_buf->data(),
1);

if (attr.nullable()) {
// For nullable attributes, if the input set for the aggregation
// contains all NULL values, we will not get an aggregate value back
// as this operation is undefined. We need to check the validity
// buffer beforehand to see if we had a valid result
if (!("count" == agg_name || "null_count" == agg_name)) {
auto *val_buf = &validity_buffers_[attr.name()][agg_name];
*val_buf = py::array(py::dtype("uint8"), 1);
query_->set_validity_buffer(attr_name + agg_name,
(uint8_t *)val_buf->data(), 1);
}
}
}
}
}

void _apply_agg_operator_to_attr(const std::string &op_label,
const std::string &attr_name) {
using AggregateFunc =
std::function<ChannelOperation(const Query &, const std::string &)>;

std::unordered_map<std::string, AggregateFunc> label_to_agg_func = {
{"sum", QueryExperimental::create_unary_aggregate<SumOperator>},
{"min", QueryExperimental::create_unary_aggregate<MinOperator>},
{"max", QueryExperimental::create_unary_aggregate<MaxOperator>},
{"mean", QueryExperimental::create_unary_aggregate<MeanOperator>},
{"null_count",
QueryExperimental::create_unary_aggregate<NullCountOperator>},
};

QueryChannel default_channel =
QueryExperimental::get_default_channel(*query_);

if (label_to_agg_func.find(op_label) != label_to_agg_func.end()) {
AggregateFunc create_unary_aggregate = label_to_agg_func.at(op_label);
ChannelOperation op = create_unary_aggregate(*query_, attr_name);
default_channel.apply_aggregate(attr_name + op_label, op);
} else if ("count" == op_label) {
default_channel.apply_aggregate(attr_name + op_label, CountOperation());
} else {
TPY_ERROR_LOC("Invalid channel operation " + op_label +
" passed to apply_aggregate.");
}
}

py::dict get_aggregate() {
query_->submit();

// Cast the results to the correct dtype and output this as a Python dict
py::dict output;
for (auto attr_to_agg : original_input_) {
// Be clear in our variable names for strings as py::dict uses py::str
// keys whereas std::map uses std::string keys
std::string attr_cpp_name = attr_to_agg.first.cast<string>();

py::str attr_py_name(attr_cpp_name);
output[attr_py_name] = py::dict();

tiledb::Attribute attr = array_->schema().attribute(attr_cpp_name);

for (auto agg_py_name : original_input_[attr_py_name]) {
std::string agg_cpp_name = agg_py_name.cast<string>();

if (_is_invalid(attr, agg_cpp_name)) {
output[attr_py_name][agg_py_name] =
_is_integer_dtype(attr) ? py::none() : py::cast(NAN);
} else {
output[attr_py_name][agg_py_name] = _set_result(attr, agg_cpp_name);
}
}
}
return output;
}

bool _is_invalid(tiledb::Attribute attr, std::string agg_name) {
if (attr.nullable()) {
if ("count" == agg_name || "null_count" == agg_name)
return false;

// For nullable attributes, check if the validity buffer returned false
const void *val_buf = validity_buffers_[attr.name()][agg_name].data();
return *((uint8_t *)(val_buf)) == 0;
} else {
// For non-nullable attributes, max and min are undefined for the empty
// set, so we must check the count == 0
if ("max" == agg_name || "min" == agg_name) {
const void *count_buf = result_buffers_[attr.name()]["count"].data();
return *((uint64_t *)(count_buf)) == 0;
}
return false;
}
}

bool _is_integer_dtype(tiledb::Attribute attr) {
switch (attr.type()) {
case TILEDB_INT8:
case TILEDB_INT16:
case TILEDB_UINT8:
case TILEDB_INT32:
case TILEDB_INT64:
case TILEDB_UINT16:
case TILEDB_UINT32:
case TILEDB_UINT64:
return true;
default:
return false;
}
}

py::object _set_result(tiledb::Attribute attr, std::string agg_name) {
const void *agg_buf = result_buffers_[attr.name()][agg_name].data();

if ("mean" == agg_name)
return py::cast(*((double *)agg_buf));

if ("count" == agg_name || "null_count" == agg_name)
return py::cast(*((uint64_t *)agg_buf));

switch (attr.type()) {
case TILEDB_FLOAT32:
return py::cast("sum" == agg_name ? *((double *)agg_buf)
: *((float *)agg_buf));
case TILEDB_FLOAT64:
return py::cast(*((double *)agg_buf));
case TILEDB_INT8:
return py::cast(*((int8_t *)agg_buf));
case TILEDB_UINT8:
return py::cast(*((uint8_t *)agg_buf));
case TILEDB_INT16:
return py::cast(*((int16_t *)agg_buf));
case TILEDB_UINT16:
return py::cast(*((uint16_t *)agg_buf));
case TILEDB_UINT32:
return py::cast(*((uint32_t *)agg_buf));
case TILEDB_INT32:
return py::cast(*((int32_t *)agg_buf));
case TILEDB_INT64:
return py::cast(*((int64_t *)agg_buf));
case TILEDB_UINT64:
return py::cast(*((uint64_t *)agg_buf));
default:
TPY_ERROR_LOC(
"[_cast_agg_result] Invalid tiledb dtype for aggregation result")
}
}

void set_subarray(py::object py_subarray) {
query_->set_subarray(*py_subarray.cast<tiledb::Subarray *>());
}

void set_cond(py::object cond) {
py::object init_pyqc = cond.attr("init_query_condition");

try {
init_pyqc(array_->uri(), attrs_, ctx_);
} catch (tiledb::TileDBError &e) {
TPY_ERROR_LOC(e.what());
} catch (py::error_already_set &e) {
TPY_ERROR_LOC(e.what());
}
auto pyqc = (cond.attr("c_obj")).cast<PyQueryCondition>();
auto qc = pyqc.ptr().get();
query_->set_condition(*qc);
}
};

class PyQuery {

private:
Context ctx_;
shared_ptr<tiledb::Domain> domain_;
shared_ptr<tiledb::ArraySchema> array_schema_;
shared_ptr<tiledb::Array> array_;
shared_ptr<tiledb::Query> query_;
std::shared_ptr<tiledb::Domain> domain_;
std::shared_ptr<tiledb::ArraySchema> array_schema_;
std::shared_ptr<tiledb::Array> array_;
std::shared_ptr<tiledb::Query> query_;
std::vector<std::string> attrs_;
std::vector<std::string> dims_;
map<string, BufferInfo> buffers_;
vector<string> buffers_order_;
std::map<std::string, BufferInfo> buffers_;
std::vector<std::string> buffers_order_;

bool deduplicate_ = true;
bool use_arrow_ = false;
Expand All @@ -320,9 +562,7 @@ class PyQuery {
tiledb_layout_t layout_ = TILEDB_ROW_MAJOR;

// label buffer list
std::unordered_map<string, uint64_t> label_input_buffer_data_;

std::string uri_;
unordered_map<string, uint64_t> label_input_buffer_data_;

public:
tiledb_ctx_t *c_ctx_;
Expand All @@ -347,15 +587,11 @@ class PyQuery {
tiledb_array_t *c_array_ = (py::capsule)array.attr("__capsule__")();

// we never own this pointer, pass own=false
array_ = std::shared_ptr<tiledb::Array>(new Array(ctx_, c_array_, false));

array_schema_ =
std::shared_ptr<tiledb::ArraySchema>(new ArraySchema(array_->schema()));
array_ = std::make_shared<tiledb::Array>(ctx_, c_array_, false);

domain_ =
std::shared_ptr<tiledb::Domain>(new Domain(array_schema_->domain()));
array_schema_ = std::make_shared<tiledb::ArraySchema>(array_->schema());

uri_ = array.attr("uri").cast<std::string>();
domain_ = std::make_shared<tiledb::Domain>(array_schema_->domain());

bool issparse = array_->schema().array_type() == TILEDB_SPARSE;

Expand Down Expand Up @@ -398,8 +634,7 @@ class PyQuery {
}
}

query_ =
std::shared_ptr<tiledb::Query>(new Query(ctx_, *array_, query_mode));
query_ = std::make_shared<tiledb::Query>(ctx_, *array_, query_mode);
// [](Query* p){} /* note: no deleter*/);

if (query_mode == TILEDB_READ) {
Expand All @@ -424,8 +659,7 @@ class PyQuery {
}

void set_subarray(py::object py_subarray) {
tiledb::Subarray *subarray = py_subarray.cast<tiledb::Subarray *>();
query_->set_subarray(*subarray);
query_->set_subarray(*py_subarray.cast<tiledb::Subarray *>());
}

#if defined(TILEDB_SERIALIZATION)
Expand Down Expand Up @@ -456,7 +690,7 @@ class PyQuery {
py::object init_pyqc = cond.attr("init_query_condition");

try {
init_pyqc(uri_, attrs_, ctx_);
init_pyqc(array_->uri(), attrs_, ctx_);
} catch (tiledb::TileDBError &e) {
TPY_ERROR_LOC(e.what());
} catch (py::error_already_set &e) {
Expand Down Expand Up @@ -1538,6 +1772,18 @@ void init_core(py::module &m) {
&PyQuery::_test_alloc_max_bytes)
.def_readonly("retries", &PyQuery::retries_);

py::class_<PyAgg>(m, "PyAgg")
.def(py::init<const Context &, py::object, py::object, py::dict>(),
"ctx"_a, "py_array"_a, "py_layout"_a, "attr_to_aggs_input"_a)
.def("set_subarray", &PyAgg::set_subarray)
.def("set_cond", &PyAgg::set_cond)
.def("get_aggregate", &PyAgg::get_aggregate);

py::class_<PAPair>(m, "PAPair")
.def(py::init())
.def("get_array", &PAPair::get_array)
.def("get_schema", &PAPair::get_schema);

m.def("array_to_buffer", &convert_np);

m.def("init_stats", &init_stats);
Expand All @@ -1548,11 +1794,6 @@ void init_core(py::module &m) {
m.def("get_stats", &get_stats);
m.def("use_stats", &use_stats);

py::class_<PAPair>(m, "PAPair")
.def(py::init())
.def("get_array", &PAPair::get_array)
.def("get_schema", &PAPair::get_schema);

/*
We need to make sure C++ TileDBError is translated to a correctly-typed py
error. Note that using py::exception(..., "TileDBError") creates a new
Expand Down
4 changes: 4 additions & 0 deletions tiledb/libtiledb.pxd
Expand Up @@ -1211,6 +1211,10 @@ cdef class SparseArrayImpl(Array):
cdef class DenseArrayImpl(Array):
cdef _read_dense_subarray(self, object subarray, list attr_names, object cond, tiledb_layout_t layout, bint include_coords)

cdef class Aggregation(object):
cdef Query query
cdef object attr_to_aggs

cdef class Query(object):
cdef Array array
cdef object attrs
Expand Down