Skip to content

Commit

Permalink
Validate search sparse vectors in segcore to make sure they are legal
Browse files Browse the repository at this point in the history
Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>
  • Loading branch information
zhengbuqian committed May 8, 2024
1 parent 01c2684 commit 7572f03
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 6 deletions.
34 changes: 29 additions & 5 deletions internal/core/src/common/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <unistd.h>

#include <cstring>
#include <cmath>
#include <filesystem>
#include <memory>
#include <string>
Expand Down Expand Up @@ -215,28 +216,51 @@ GetCommonPrefix(const std::string& str1, const std::string& str2) {
}

// Deep-copies serialized sparse row data starting at `data` into a newly
// constructed knowhere::sparse::SparseRow<float> and returns it.
//
// The element count is size / SparseRow<float>::element_size(). When
// `validate` is true, the byte length is checked before anything is copied and
// the copied elements are checked afterwards; any violation fails an
// AssertInfo:
//   - size must be non-zero and a multiple of the element size;
//   - every value must be finite and non-negative;
//   - every id must be smaller than uint32 max;
//   - ids must be strictly ascending.
// When `validate` is false no checks are performed.
inline knowhere::sparse::SparseRow<float>
CopyAndWrapSparseRow(const void* data,
                     size_t size,
                     const bool validate = false) {
    const size_t element_size =
        knowhere::sparse::SparseRow<float>::element_size();
    if (validate) {
        // Reject malformed byte lengths up front, before any memory is
        // allocated or written.
        AssertInfo(size > 0, "Sparse row data should not be empty");
        AssertInfo(size % element_size == 0,
                   "Invalid size for sparse row data");
    }
    const size_t num_elements = size / element_size;
    knowhere::sparse::SparseRow<float> row(num_elements);
    // Copy whole elements only: the row was constructed for num_elements
    // elements, so copying a trailing partial element (size not a multiple of
    // element_size) would write more bytes than the row was sized for.
    std::memcpy(row.data(), data, num_elements * element_size);
    if (validate) {
        for (size_t i = 0; i < num_elements; ++i) {
            const auto element = row[i];
            AssertInfo(std::isfinite(element.val),
                       "Invalid sparse row: NaN or Inf value");
            AssertInfo(element.val >= 0, "Invalid sparse row: negative value");
            AssertInfo(
                element.id < std::numeric_limits<uint32_t>::max(),
                "Invalid sparse row: id should be smaller than uint32 max");
            if (i > 0) {
                AssertInfo(row[i - 1].id < element.id,
                           "Invalid sparse row: id should be strict ascending");
            }
        }
    }
    return row;
}

// Iterable is a list of bytes, each is a byte array representation of a single
// sparse float row. This helper function converts such byte arrays into a list
// of knowhere::sparse::SparseRow<float>. The resulting list is a deep copy of
// the source data.
//
// Here in segcore we validate the sparse row data only for search requests,
// as the insert/upsert data are already validated in go code.
template <typename Iterable>
std::unique_ptr<knowhere::sparse::SparseRow<float>[]>
SparseBytesToRows(const Iterable& rows) {
SparseBytesToRows(const Iterable& rows, const bool validate = false) {
AssertInfo(rows.size() > 0, "at least 1 sparse row should be provided");
auto res =
std::make_unique<knowhere::sparse::SparseRow<float>[]>(rows.size());
for (size_t i = 0; i < rows.size(); ++i) {
res[i] =
std::move(CopyAndWrapSparseRow(rows[i].data(), rows[i].size()));
res[i] = std::move(
CopyAndWrapSparseRow(rows[i].data(), rows[i].size(), validate));
}
return res;
}
Expand Down
3 changes: 2 additions & 1 deletion internal/core/src/query/Plan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ ParsePlaceholderGroup(const Plan* plan,
AssertInfo(element.num_of_queries_ > 0, "must have queries");
if (info.type() ==
milvus::proto::common::PlaceholderType::SparseFloatVector) {
element.sparse_matrix_ = SparseBytesToRows(info.values());
element.sparse_matrix_ =
SparseBytesToRows(info.values(), /*validate=*/true);
} else {
auto line_size = info.values().Get(0).size();
if (field_meta.get_sizeof() != line_size) {
Expand Down

0 comments on commit 7572f03

Please sign in to comment.