Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Commit

Permalink
Change l1 method name to l1Norm for painless extensions (#316)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmazanec15 committed Feb 3, 2021
1 parent 9381e7e commit ba5c91d
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ public L1(Object query, MappedFieldType fieldType) {

this.processedQuery = parseToFloatArray(query,
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension());
this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l1distance(q, v));
this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l1Norm(q, v));
}

public ScoreScript getScoreScript(Map<String, Object> params, String field, SearchLookup lookup,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ public static float calculateHammingBit(Long queryLong, Long inputLong) {
* @param inputVector input vector
* @return L1 score
*/
public static float l1distance(float[] queryVector, float[] inputVector) {
public static float l1Norm(float[] queryVector, float[] inputVector) {
requireEqualDimension(queryVector, inputVector);
float distance = 0;
for (int i = 0; i < inputVector.length; i++) {
Expand All @@ -232,7 +232,7 @@ public static float l1distance(float[] queryVector, float[] inputVector) {
* and document vectors
* Example
* "script": {
* "source": "1/(1 + l1distance(params.query_vector, doc[params.field]))",
* "source": "1/(1 + l1Norm(params.query_vector, doc[params.field]))",
* "params": {
* "query_vector": [1, 2, 3.4],
* "field": "my_dense_vector"
Expand All @@ -243,7 +243,7 @@ public static float l1distance(float[] queryVector, float[] inputVector) {
* @param docValues script doc values
* @return L1 score
*/
public static float l1distance(List<Number> queryVector, KNNVectorScriptDocValues docValues) {
return l1distance(toFloat(queryVector), docValues.getValue());
public static float l1Norm(List<Number> queryVector, KNNVectorScriptDocValues docValues) {
return l1Norm(toFloat(queryVector), docValues.getValue());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class com.amazon.opendistroforelasticsearch.knn.index.KNNVectorScriptDocValues {
}
static_import {
float l2Squared(List, com.amazon.opendistroforelasticsearch.knn.index.KNNVectorScriptDocValues) from_class com.amazon.opendistroforelasticsearch.knn.plugin.script.KNNScoringUtil
float l1distance(List, com.amazon.opendistroforelasticsearch.knn.index.KNNVectorScriptDocValues) from_class com.amazon.opendistroforelasticsearch.knn.plugin.script.KNNScoringUtil
float l1Norm(List, com.amazon.opendistroforelasticsearch.knn.index.KNNVectorScriptDocValues) from_class com.amazon.opendistroforelasticsearch.knn.plugin.script.KNNScoringUtil
float cosineSimilarity(List, com.amazon.opendistroforelasticsearch.knn.index.KNNVectorScriptDocValues) from_class com.amazon.opendistroforelasticsearch.knn.plugin.script.KNNScoringUtil
float cosineSimilarity(List, com.amazon.opendistroforelasticsearch.knn.index.KNNVectorScriptDocValues, Number) from_class com.amazon.opendistroforelasticsearch.knn.plugin.script.KNNScoringUtil
}
Original file line number Diff line number Diff line change
Expand Up @@ -256,15 +256,15 @@ public void testCosineSimilarityNormalizedScriptScoreWithNumericField() throws E

// L1 tests
public void testL1ScriptScoreFails() throws Exception {
String source = String.format("1/(1 + l1distance([1.0f, 1.0f], doc['%s']))", FIELD_NAME);
String source = String.format("1/(1 + l1Norm([1.0f, 1.0f], doc['%s']))", FIELD_NAME);
Request request = buildPainlessScriptRequest(source, 3, getL1TestData());
addDocWithNumericField(INDEX_NAME, "100", NUMERIC_INDEX_FIELD_NAME, 1000);
expectThrows(ResponseException.class, () -> client().performRequest(request));
deleteKNNIndex(INDEX_NAME);
}
public void testL1ScriptScore() throws Exception {

String source = String.format("1/(1 + l1distance([1.0f, 1.0f], doc['%s']))", FIELD_NAME);
String source = String.format("1/(1 + l1Norm([1.0f, 1.0f], doc['%s']))", FIELD_NAME);
Request request = buildPainlessScriptRequest(source, 3, getL1TestData());

Response response = client().performRequest(request);
Expand All @@ -285,7 +285,7 @@ public void testL1ScriptScore() throws Exception {
public void testL1ScriptScoreWithNumericField() throws Exception {

String source = String.format(
"doc['%s'].size() == 0 ? 0 : 1/(1 + l1distance([1.0f, 1.0f], doc['%s']))", FIELD_NAME, FIELD_NAME);
"doc['%s'].size() == 0 ? 0 : 1/(1 + l1Norm([1.0f, 1.0f], doc['%s']))", FIELD_NAME, FIELD_NAME);
Request request = buildPainlessScriptRequest(source, 3, getL1TestData());
addDocWithNumericField(INDEX_NAME, "100", NUMERIC_INDEX_FIELD_NAME, 1000);
Response response = client().performRequest(request);
Expand Down

0 comments on commit ba5c91d

Please sign in to comment.