Skip to content

Commit

Permalink
Merge pull request #1273 from happy-san/v0.26-facets
Browse files Browse the repository at this point in the history
Filter scoring
  • Loading branch information
kishorenc committed Oct 4, 2023
2 parents ce4b8e3 + 0d81a66 commit a4ab04c
Show file tree
Hide file tree
Showing 10 changed files with 419 additions and 146 deletions.
17 changes: 14 additions & 3 deletions include/field.h
Original file line number Diff line number Diff line change
Expand Up @@ -496,12 +496,14 @@ struct sort_by {
};

struct eval_t {
filter_node_t* filter_tree_root = nullptr;
uint32_t* ids = nullptr;
uint32_t size = 0;
filter_node_t* filter_trees = nullptr;
std::vector<uint32_t*> eval_ids_vec;
std::vector<uint32_t> eval_ids_count_vec;
std::vector<int64_t> scores;
};

std::string name;
std::vector<std::string> eval_expressions;
std::string order;

// for text_match score bucketing
Expand All @@ -523,6 +525,13 @@ struct sort_by {

}

sort_by(std::vector<std::string> eval_expressions, std::vector<int64_t> scores, std::string order):
eval_expressions(std::move(eval_expressions)), order(std::move(order)), text_match_buckets(0), geopoint(0), exclude_radius(0),
geo_precision(0), missing_values(normal) {
name = sort_field_const::eval;
eval.scores = std::move(scores);
}

sort_by(const std::string &name, const std::string &order, uint32_t text_match_buckets, int64_t geopoint,
uint32_t exclude_radius, uint32_t geo_precision) :
name(name), order(order), text_match_buckets(text_match_buckets),
Expand All @@ -535,6 +544,7 @@ struct sort_by {
if (&other == this)
return;
name = other.name;
eval_expressions = other.eval_expressions;
order = other.order;
text_match_buckets = other.text_match_buckets;
geopoint = other.geopoint;
Expand All @@ -547,6 +557,7 @@ struct sort_by {

sort_by& operator=(const sort_by& other) {
name = other.name;
eval_expressions = other.eval_expressions;
order = other.order;
text_match_buckets = other.text_match_buckets;
geopoint = other.geopoint;
Expand Down
25 changes: 24 additions & 1 deletion include/filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ struct filter_node_t {
filter_node_t* left = nullptr;
filter_node_t* right = nullptr;

filter_node_t(filter filter_exp)
filter_node_t() = default;

explicit filter_node_t(filter filter_exp)
: filter_exp(std::move(filter_exp)),
isOperator(false),
left(nullptr),
Expand All @@ -85,4 +87,25 @@ struct filter_node_t {
delete left;
delete right;
}

filter_node_t& operator=(filter_node_t&& obj) noexcept {
if (&obj == this) {
return *this;
}

if (obj.isOperator) {
isOperator = true;
filter_operator = obj.filter_operator;
left = obj.left;
right = obj.right;

obj.left = nullptr;
obj.right = nullptr;
} else {
isOperator = false;
filter_exp = obj.filter_exp;
}

return *this;
}
};
2 changes: 2 additions & 0 deletions include/filter_result_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ class filter_result_iterator_t {
/// iterator reaching it's end. (is_valid would be false in both these cases)
uint32_t approx_filter_ids_length = 0;

filter_result_iterator_t() = default;

explicit filter_result_iterator_t(uint32_t* ids, const uint32_t& ids_count);

explicit filter_result_iterator_t(const std::string collection_name,
Expand Down
2 changes: 1 addition & 1 deletion include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,7 @@ class Index {
std::array<spp::sparse_hash_map<uint32_t, int64_t>*, 3> field_values,
const std::vector<size_t>& geopoint_indices, uint32_t seq_id,
const std::map<basic_string<char>, reference_filter_result_t>& references,
size_t filter_index,
std::vector<uint32_t>& filter_indexes,
int64_t max_field_match_score,
int64_t* scores,
int64_t& match_score_index, float vector_distance = 0,
Expand Down
105 changes: 65 additions & 40 deletions src/collection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,10 @@ struct sort_fields_guard_t {

~sort_fields_guard_t() {
for(auto& sort_by_clause: sort_fields_std) {
delete sort_by_clause.eval.filter_tree_root;
if(sort_by_clause.eval.ids) {
delete [] sort_by_clause.eval.ids;
sort_by_clause.eval.ids = nullptr;
sort_by_clause.eval.size = 0;
for (auto& eval_ids: sort_by_clause.eval.eval_ids_vec) {
delete [] eval_ids;
}
delete [] sort_by_clause.eval.filter_trees;
}
}
};
Expand Down Expand Up @@ -774,7 +772,7 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
const bool is_vector_query,
const bool is_group_by_query) const {

size_t num_sort_expressions = 0;
uint32_t eval_sort_count = 0;

for(size_t i = 0; i < sort_fields.size(); i++) {
const sort_by& _sort_field = sort_fields[i];
Expand Down Expand Up @@ -813,6 +811,37 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
sort_fields_std.emplace_back(ref_sort_field_std);
}

continue;
} else if (_sort_field.name == sort_field_const::eval) {
sort_by sort_field_std(sort_field_const::eval, _sort_field.order);

auto const& count = _sort_field.eval_expressions.size();
sort_field_std.eval.filter_trees = new filter_node_t[count];
std::unique_ptr<filter_node_t []> filter_trees_guard(sort_field_std.eval.filter_trees);

for (uint32_t j = 0; j < count; j++) {
auto const& filter_exp = _sort_field.eval_expressions[j];
if (filter_exp.empty()) {
return Option<bool>(400, "The eval expression in sort_by is empty.");
}

filter_node_t* filter_tree_root = nullptr;
Option<bool> parse_filter_op = filter::parse_filter_query(filter_exp, search_schema,
store, "", filter_tree_root);
std::unique_ptr<filter_node_t> filter_tree_root_guard(filter_tree_root);

if (!parse_filter_op.ok()) {
return Option<bool>(parse_filter_op.code(), "Error parsing eval expression in sort_by clause.");
}

sort_field_std.eval.filter_trees[j] = std::move(*filter_tree_root);
}

eval_sort_count++;
sort_field_std.eval_expressions = _sort_field.eval_expressions;
sort_field_std.eval.scores = _sort_field.eval.scores;
sort_fields_std.emplace_back(sort_field_std);
filter_trees_guard.release();
continue;
}

Expand Down Expand Up @@ -843,23 +872,6 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
sort_field_std.name = actual_field_name;
sort_field_std.text_match_buckets = std::stoll(match_parts[1]);

} else if(actual_field_name == sort_field_const::eval) {
const std::string& filter_exp = sort_field_std.name.substr(paran_start + 1,
sort_field_std.name.size() - paran_start -
2);
if(filter_exp.empty()) {
return Option<bool>(400, "The eval expression in sort_by is empty.");
}

Option<bool> parse_filter_op = filter::parse_filter_query(filter_exp, search_schema,
store, "", sort_field_std.eval.filter_tree_root);
if(!parse_filter_op.ok()) {
return Option<bool>(parse_filter_op.code(), "Error parsing eval expression in sort_by clause.");
}

sort_field_std.name = actual_field_name;
num_sort_expressions++;

} else {
if(field_it == search_schema.end()) {
std::string error = "Could not find a field named `" + actual_field_name + "` in the schema for sorting.";
Expand Down Expand Up @@ -1063,7 +1075,7 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
return Option<bool>(422, message);
}

if(num_sort_expressions > 1) {
if(eval_sort_count > 1) {
std::string message = "Only one sorting eval expression is allowed.";
return Option<bool>(422, message);
}
Expand Down Expand Up @@ -1106,22 +1118,6 @@ Option<bool> Collection::validate_and_standardize_sort_fields_with_lock(const st
sort_field_std.name = actual_field_name;
sort_field_std.text_match_buckets = std::stoll(match_parts[1]);

} else if(actual_field_name == sort_field_const::eval) {
const std::string& filter_exp = sort_field_std.name.substr(paran_start + 1,
sort_field_std.name.size() - paran_start -
2);
if(filter_exp.empty()) {
return Option<bool>(400, "The eval expression in sort_by is empty.");
}

Option<bool> parse_filter_op = filter::parse_filter_query(filter_exp, search_schema,
store, "", sort_field_std.eval.filter_tree_root);
if(!parse_filter_op.ok()) {
return Option<bool>(parse_filter_op.code(), "Error parsing eval expression in sort_by clause.");
}

sort_field_std.name = actual_field_name;

} else {
if(field_it == search_schema.end()) {
std::string error = "Could not find a field named `" + actual_field_name + "` in the schema for sorting.";
Expand Down Expand Up @@ -1243,6 +1239,35 @@ Option<bool> Collection::validate_and_standardize_sort_fields_with_lock(const st

sort_field_std.name = actual_field_name;
}
} else if (sort_field.name == sort_field_const::eval) {
auto const& count = sort_field.eval_expressions.size();
sort_field_std.eval.filter_trees = new filter_node_t[count];
std::unique_ptr<filter_node_t []> filter_trees_guard(sort_field_std.eval.filter_trees);

for (uint32_t j = 0; j < count; j++) {
auto const& filter_exp = sort_field.eval_expressions[j];
if (filter_exp.empty()) {
return Option<bool>(400, "The eval expression in sort_by is empty.");
}

filter_node_t* filter_tree_root = nullptr;
Option<bool> parse_filter_op = filter::parse_filter_query(filter_exp, search_schema,
store, "", filter_tree_root);
std::unique_ptr<filter_node_t> filter_tree_root_guard(filter_tree_root);

if (!parse_filter_op.ok()) {
return Option<bool>(parse_filter_op.code(), "Error parsing eval expression in sort_by clause.");
}

sort_field_std.eval.filter_trees[j] = std::move(*filter_tree_root);
}

sort_field_std.name = sort_field.name;
sort_field_std.eval_expressions = sort_field.eval_expressions;
sort_field_std.eval.scores = sort_field.eval.scores;
sort_fields_std.emplace_back(sort_field_std);
filter_trees_guard.release();
continue;
}

if (sort_field_std.name != sort_field_const::text_match && sort_field_std.name != sort_field_const::eval &&
Expand Down

0 comments on commit a4ab04c

Please sign in to comment.