Skip to content

Commit

Permalink
Fix recommend best_score formula (#2904)
Browse files Browse the repository at this point in the history
* fix recommend best score formula

* fix integration test

* review fixes

* rebase and fix conflicts

* correct proptest comment
  • Loading branch information
coszio committed Nov 2, 2023
1 parent 4700e2a commit 50f3ef1
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 84 deletions.
11 changes: 6 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 10 additions & 2 deletions lib/collection/src/shards/local_shard_operations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::collection_manager::segments_searcher::SegmentsSearcher;
use crate::common::stopping_guard::StoppingGuard;
use crate::operations::types::{
CollectionError, CollectionInfo, CollectionResult, CoreSearchRequestBatch, CountRequest,
CountResult, PointRequest, Record, SearchRequestBatch, UpdateResult, UpdateStatus,
CountResult, PointRequest, QueryEnum, Record, SearchRequestBatch, UpdateResult, UpdateStatus,
};
use crate::operations::CollectionUpdateOperations;
use crate::optimizers_builder::DEFAULT_INDEXING_THRESHOLD_KB;
Expand Down Expand Up @@ -76,7 +76,15 @@ impl LocalShard {
.unwrap()
.distance;
let processed_res = vector_res.into_iter().map(|mut scored_point| {
scored_point.score = distance.postprocess_score(scored_point.score);
match req.query {
QueryEnum::Nearest(_) => {
scored_point.score = distance.postprocess_score(scored_point.score);
}
// Don't post-process if we are dealing with custom scoring
QueryEnum::RecommendBestScore(_)
| QueryEnum::Discover(_)
| QueryEnum::Context(_) => {}
};
scored_point
});

Expand Down
1 change: 1 addition & 0 deletions lib/segment/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ rmp-serde = "~1.1"
rand_distr = "0.4.3"
walkdir = "2.4.0"
rstest = "0.18.2"
proptest = "1.3.1"

[target.'cfg(not(target_os = "windows"))'.dev-dependencies]
pprof = { version = "0.12", features = ["flamegraph", "prost-codec"] }
Expand Down
92 changes: 80 additions & 12 deletions lib/segment/src/vector_storage/query/reco_query.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use common::math::scaled_fast_sigmoid;
use common::types::ScoreType;

use super::{Query, TransformInto};
Expand Down Expand Up @@ -61,9 +62,9 @@ fn merge_similarities(
.unwrap_or(ScoreType::NEG_INFINITY);

if max_positive > max_negative {
max_positive
scaled_fast_sigmoid(max_positive)
} else {
-(max_negative * max_negative)
-scaled_fast_sigmoid(max_negative)
}
}

Expand All @@ -75,33 +76,100 @@ impl From<RecoQuery<Vector>> for QueryVector {

#[cfg(test)]
mod test {
use common::math::scaled_fast_sigmoid;
use common::types::ScoreType;
use proptest::prelude::*;
use rstest::rstest;

use super::RecoQuery;
use crate::vector_storage::query::Query;

enum Chosen {
Positive,
Negative,
}

#[rstest]
#[case::higher_positive(vec![42], vec![4], 42.0)]
#[case::higher_negative(vec![4], vec![42], -(42.0 * 42.0))]
#[case::negative_zero(vec![-1], vec![0], 0.0)]
#[case::positive_zero(vec![0], vec![-1], 0.0)]
#[case::both_under_zero(vec![-42], vec![-84], -42.0)]
#[case::both_under_zero_but_negative_is_higher(vec![-84], vec![-42], -(42.0 * 42.0))]
#[case::multiple_with_negative_best(vec![1, 2, 3], vec![4, 5, 6], -(6.0 * 6.0))]
#[case::multiple_with_positive_best(vec![10, 2, 3], vec![4, 5, 6], 10.0)]
#[case::no_input(vec![], vec![], ScoreType::NEG_INFINITY)]
#[case::higher_positive(vec![42], vec![4], Chosen::Positive, 42.0)]
#[case::higher_negative(vec![4], vec![42], Chosen::Negative, 42.0)]
#[case::negative_zero(vec![-1], vec![0], Chosen::Negative, 0.0)]
#[case::positive_zero(vec![0], vec![-1], Chosen::Positive, 0.0)]
#[case::both_under_zero(vec![-42], vec![-84], Chosen::Positive, -42.0)]
#[case::both_under_zero_but_negative_is_higher(vec![-84], vec![-42], Chosen::Negative, -42.0)]
#[case::multiple_with_negative_best(vec![1, 2, 3], vec![4, 5, 6], Chosen::Negative, 6.0)]
#[case::multiple_with_positive_best(vec![10, 2, 3], vec![4, 5, 6], Chosen::Positive, 10.0)]
fn score_query(
#[case] positives: Vec<isize>,
#[case] negatives: Vec<isize>,
#[case] chosen: Chosen,
#[case] expected: ScoreType,
) {
let query = RecoQuery::new(positives, negatives);

let dummy_similarity = |x: &isize| *x as ScoreType;

let positive_transformation = scaled_fast_sigmoid;
let negative_transformation = |x| -scaled_fast_sigmoid(x);

let score = query.score_by(dummy_similarity);

assert_eq!(score, expected);
match chosen {
Chosen::Positive => {
assert_eq!(score, positive_transformation(expected));
}
Chosen::Negative => {
assert_eq!(score, negative_transformation(expected));
}
}
}

proptest! {
/// Checks that the negative-chosen scores invert the order of the candidates
#[test]
fn correct_negative_order(a in -100f32..=100f32, b in -100f32..=100f32) {
let dummy_similarity = |x: &f32| *x as ScoreType;

let ordering_before = dummy_similarity(&a).total_cmp(&dummy_similarity(&b));

let query_a = RecoQuery::new(vec![], vec![a]);
let query_b = RecoQuery::new(vec![], vec![b]);

let ordering_after = query_a.score_by(dummy_similarity).total_cmp(&query_b.score_by(dummy_similarity));

if ordering_before == std::cmp::Ordering::Equal {
assert_eq!(ordering_before, ordering_after);
} else {
assert_ne!(ordering_before, ordering_after)
}
}

/// Checks that the positive-chosen scores preserve the order of the candidates
#[test]
fn correct_positive_order(a in -100f32..=100f32, b in -100f32..=100f32) {
let dummy_similarity = |x: &f32| *x as ScoreType;

let ordering_before = dummy_similarity(&a).total_cmp(&dummy_similarity(&b));

let query_a = RecoQuery::new(vec![a], vec![]);
let query_b = RecoQuery::new(vec![b], vec![]);

let ordering_after = query_a.score_by(dummy_similarity).total_cmp(&query_b.score_by(dummy_similarity));

assert_eq!(ordering_before, ordering_after);
}

/// Guarantees that the point that was chosen from positive is always preferred on
/// the candidate list over a point that was chosen from negatives
#[test]
fn correct_positive_and_negative_order(p in -100f32..=100f32, n in -100f32..=100f32) {
let dummy_similarity = |x: &f32| *x as ScoreType;

let query_p = RecoQuery::new(vec![p], vec![]);
let query_n = RecoQuery::new(vec![], vec![n]);

let ordering = query_p.score_by(dummy_similarity).total_cmp(&query_n.score_by(dummy_similarity));

assert_ne!(ordering, std::cmp::Ordering::Less);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ fn random_reco_query<R: Rng + ?Sized>(
sampler: &mut impl Iterator<Item = f32>,
) -> QueryVector {
let num_positives: usize = rnd.gen_range(0..MAX_EXAMPLES);
let num_negatives: usize = rnd.gen_range(0..MAX_EXAMPLES);
let num_negatives: usize = rnd.gen_range(1..MAX_EXAMPLES);

let positives = (0..num_positives)
.map(|_| sampler.take(DIMS).collect_vec().into())
Expand Down Expand Up @@ -236,7 +236,7 @@ fn scoring_equivalency(
let quantized_vectors = quantized_vectors.as_ref().map(|q| q.borrow());

let attempts = 50;
for _i in 0..attempts {
for i in 0..attempts {
let query = random_query(&query_variant, &mut rng, &mut sampler);

let raw_scorer = new_raw_scorer(
Expand All @@ -251,15 +251,18 @@ fn scoring_equivalency(
let other_scorer = match &quantized_vectors {
Some(quantized_storage) => quantized_storage
.raw_scorer(
query,
query.clone(),
id_tracker.deleted_point_bitslice(),
other_storage.deleted_vector_bitslice(),
&is_stopped,
)
.unwrap(),
None => {
new_raw_scorer(query, &other_storage, id_tracker.deleted_point_bitslice()).unwrap()
}
None => new_raw_scorer(
query.clone(),
&other_storage,
id_tracker.deleted_point_bitslice(),
)
.unwrap(),
};

let points =
Expand All @@ -273,8 +276,8 @@ fn scoring_equivalency(
// both calculations are done on raw vectors, so score should be exactly the same
assert_eq!(
raw_scores, other_scores,
"Scorer results are not equal, attempt: {}",
_i
"Scorer results are not equal, attempt: {}, query: {:?}",
i, query
);
} else {
// Quantization is used for the other storage, so score should be similar
Expand Down Expand Up @@ -302,7 +305,7 @@ fn scoring_equivalency(

assert!(
(intersection as f32 / top as f32) >= 0.7, // at least 70% of top 10% results should be shared
"Top results from scorers are not similar, attempt {_i}:
"Top results from scorers are not similar, attempt {i}:
top raw: {raw_top:?},
top other: {other_top:?}
only {intersection} of {top} top results are shared",
Expand Down

0 comments on commit 50f3ef1

Please sign in to comment.