Skip to content

Commit

Permalink
Fix default metric and virtual tensor logic (#2616)
Browse files: browse the repository at this point in the history
* Fixes so that vector store works with existing indexes

* tiny fix

* Ignore chunk_engine attribute for virtual tensors.

* Fixed `black` formatting.

* fixes

* Addressed review comments.

* Fixed bug and added basic index tests

* test tweak

---------

Co-authored-by: Sasun Hambardzumyan <xustup@gmail.com>
  • Loading branch information
istranic and khustup committed Sep 26, 2023
1 parent 1fad5ae commit 1e08b0e
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 54 deletions.
6 changes: 5 additions & 1 deletion deeplake/core/dataset/deeplake_query_tensor.py
Expand Up @@ -108,7 +108,7 @@ def htype(self):

@htype.setter
def htype(self, value):
    """Reject assignment: the htype of a virtual tensor is read-only.

    Args:
        value: Ignored; no value is accepted.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("htype of a virtual tensor cannot be set.")

@property
def sample_compression(self):
Expand All @@ -133,6 +133,10 @@ def max_shape(self):
def min_shape(self):
return self.indra_tensor.min_shape

# Accessing `chunk_engine` on a virtual tensor always raises NotImplementedError;
# the message states that a virtual tensor does not have a chunk engine.
@property
def chunk_engine(self):
raise NotImplementedError("Virtual tensor does not have chunk engine.")

@property
def shape(self):
if (
Expand Down
19 changes: 10 additions & 9 deletions deeplake/core/tensor.py
Expand Up @@ -996,16 +996,17 @@ def data(self, aslist: bool = False, fetch_chunks: bool = False) -> Any:
)
value = parse_mesh_to_dict(full_arr, self.sample_info)
return value
elif hasattr(self, "chunk_engine"):
return {
"value": self.chunk_engine.numpy(
index=self.index, aslist=aslist, fetch_chunks=fetch_chunks
),
}
else:
return {
"value": self.numpy(aslist=aslist, fetch_chunks=fetch_chunks),
}
try:
return {
"value": self.chunk_engine.numpy(
index=self.index, aslist=aslist, fetch_chunks=fetch_chunks
),
}
except NotImplementedError:
return {
"value": self.numpy(aslist=aslist, fetch_chunks=fetch_chunks),
}

def tobytes(self) -> bytes:
"""Returns the bytes of the tensor.
Expand Down
10 changes: 5 additions & 5 deletions deeplake/core/vectorstore/deeplake_vectorstore.py
Expand Up @@ -174,10 +174,10 @@ def __init__(
self._exec_option = exec_option
self.verbose = verbose
self.tensor_params = tensor_params
self.index_created = False
self.distance_metric_index = False
if utils.index_used(self.exec_option):
index.index_cache_cleanup(self.dataset)
self.index_created = index.validate_and_create_vector_index(
self.distance_metric_index = index.validate_and_create_vector_index(
dataset=self.dataset,
index_params=self.index_params,
regenerate_index=False,
Expand Down Expand Up @@ -349,7 +349,7 @@ def add(

if utils.index_used(self.exec_option):
index.index_cache_cleanup(self.dataset)
self.index_created = index.validate_and_create_vector_index(
self.distance_metric_index = index.validate_and_create_vector_index(
dataset=self.dataset,
index_params=self.index_params,
regenerate_index=index_regeneration,
Expand Down Expand Up @@ -496,9 +496,9 @@ def search(
embedding_function=embedding_function or self.embedding_function,
)

if self.index_created:
if self.distance_metric_index:
distance_metric = index.get_index_distance_metric_from_params(
logger, self.index_params, distance_metric
logger, self.distance_metric_index, distance_metric
)

distance_metric = distance_metric or DEFAULT_VECTORSTORE_DISTANCE_METRIC
Expand Down
38 changes: 38 additions & 0 deletions deeplake/core/vectorstore/test_deeplake_vectorstore.py
Expand Up @@ -15,13 +15,15 @@
from deeplake.tests.common import requires_libdeeplake
from deeplake.constants import (
DEFAULT_VECTORSTORE_TENSORS,
DEFAULT_VECTORSTORE_DISTANCE_METRIC,
)
from deeplake.constants import MB
from deeplake.util.exceptions import (
IncorrectEmbeddingShapeError,
TensorDoesNotExistError,
DatasetHandlerError,
)
from deeplake.core.vectorstore.vector_search.indra.index import METRIC_TO_INDEX_METRIC
from deeplake.core.vectorstore.vector_search import dataset as dataset_utils
from deeplake.cli.auth import login, logout
from click.testing import CliRunner
Expand Down Expand Up @@ -509,6 +511,42 @@ def filter_fn(x):
assert len(result) == 4


@requires_libdeeplake
def test_index_basic(local_path, hub_cloud_dev_token):
    """Check `distance_metric_index` before, during and after index creation."""
    # Local import keeps this test self-contained with respect to the module's imports.
    import warnings

    expected_index_metric = METRIC_TO_INDEX_METRIC[
        DEFAULT_VECTORSTORE_DISTANCE_METRIC
    ]

    # Start by testing behavior without an index
    vector_store = VectorStore(
        path=local_path,
        overwrite=True,
        token=hub_cloud_dev_token,
    )
    vector_store.add(embedding=embeddings, text=texts, metadata=metadatas)
    assert vector_store.distance_metric_index is None

    # Then test behavior when index is added
    vector_store = VectorStore(
        path=local_path, token=hub_cloud_dev_token, index_params={"threshold": 1}
    )
    assert vector_store.distance_metric_index == expected_index_metric

    # Then test behavior when index is added previously and the dataset is reloaded
    vector_store = VectorStore(path=local_path, token=hub_cloud_dev_token)
    assert vector_store.distance_metric_index == expected_index_metric

    # An explicit `distance_metric` is overridden by the metric stored in the index,
    # so even an invalid name must not make the search fail.
    # `pytest.warns(None)` is deprecated in pytest 7 and an error in pytest 8; it
    # only recorded warnings without asserting on them, which is what this does.
    with warnings.catch_warnings(record=True):
        warnings.simplefilter("always")
        vector_store.search(embedding=query_embedding, distance_metric="blabla")


@pytest.mark.slow
@requires_libdeeplake
@pytest.mark.parametrize("distance_metric", ["L1", "L2", "COS", "MAX"])
Expand Down
106 changes: 67 additions & 39 deletions deeplake/core/vectorstore/vector_search/indra/index.py
Expand Up @@ -5,17 +5,28 @@

# Maps the user-facing distance metric name to the name stored in the vector index.
METRIC_TO_INDEX_METRIC = {
    "L2": "l2_norm",
    "L1": "l1_norm",
    "COS": "cosine_similarity",
}


def get_index_distance_metric_from_params(
    logger, distance_metric_index, distance_metric
):
    """Translate the metric stored in the index back to its user-facing name.

    Args:
        logger: Object with a ``warning(message)`` method; used to report that a
            user-supplied ``distance_metric`` is being ignored.
        distance_metric_index: Metric name as stored in the index, i.e. one of the
            values of ``METRIC_TO_INDEX_METRIC``.
        distance_metric: Metric requested by the caller. Any truthy value only
            triggers a warning; the index metric always wins.

    Returns:
        The key of ``METRIC_TO_INDEX_METRIC`` whose value equals
        ``distance_metric_index``.

    Raises:
        ValueError: If ``distance_metric_index`` is not a value of
            ``METRIC_TO_INDEX_METRIC``.
    """
    if distance_metric:
        logger.warning(
            f"Specifying `distance_metric` for a Vector Store with an index is not supported; `distance_metric` was specified as: `{distance_metric}`. "
            f"The search will be performed using the distance metric from the index: `{distance_metric_index}`"
        )

    for metric, index_metric in METRIC_TO_INDEX_METRIC.items():
        if index_metric == distance_metric_index:
            return metric

    # The rejected value is an index-side name, so list the index-side names
    # (the dict values) as the valid options rather than the user-facing keys.
    raise ValueError(
        f"Invalid distance metric in the index: {distance_metric_index}. "
        f"Valid options are: {', '.join(METRIC_TO_INDEX_METRIC.values())}"
    )


def get_index_metric(metric):
Expand Down Expand Up @@ -87,47 +98,64 @@ def index_cache_cleanup(dataset):


def validate_and_create_vector_index(dataset, index_params, regenerate_index=False):
    """Return the distance metric of the dataset's vector index, creating it if required.

    Validate if the index is present in the dataset and create one if not present
    but required based on the specified index_params.
    Currently only supports 1 index per dataset: the first embedding tensor that
    already has an index, or gets one created here, determines the result.

    Args:
        dataset: Dataset whose ``tensors`` are inspected; ``len(dataset)`` is
            compared against the threshold.
        index_params: Mapping read for ``"threshold"`` (default -1),
            ``"distance_metric"`` (default "L2") and ``"additional_params"``.
        regenerate_index: Currently unused; the regeneration logic is disabled
            (see the TODO in the body).

    Returns:
        Distance metric for the index. If None, then no index is available.

    Raises:
        ValueError: Propagated unchanged from resolving the metric or creating
            the index.

    TODO: Update to support multiple indexes per dataset, only once the TQL parser also supports that
    """
    threshold = index_params.get("threshold", -1)

    # A non-positive threshold, or a dataset smaller than the threshold, means no
    # new index is created; an already existing index is still reported.
    below_threshold = threshold <= 0 or len(dataset) < threshold

    tensors = dataset.tensors

    # TODO: BRING BACK WHEN IT IS IN USE
    #
    # index_regen = False
    # Check if regenerate_index is true.
    # if regenerate_index:
    #     for _, tensor in tensors.items():
    #         is_embedding = utils.is_embedding_tensor(tensor)
    #         has_vdb_indexes = hasattr(tensor.meta, "vdb_indexes")
    #         try:
    #             vdb_index_ids_present = len(tensor.get_vdb_indexes()) > 0
    #         except AttributeError:
    #             vdb_index_ids_present = False
    #
    #         if is_embedding and has_vdb_indexes and vdb_index_ids_present:
    #             tensor._regenerate_vdb_indexes()
    #             index_regen = True
    #     if index_regen:
    #         return

    # Check all tensors from the dataset.
    for tensor in tensors.values():
        if not utils.is_embedding_tensor(tensor):
            continue

        vdb_indexes = tensor.get_vdb_indexes()
        if len(vdb_indexes) > 0:
            # An index already exists: report its metric regardless of threshold.
            return vdb_indexes[0]["distance"]
        if below_threshold:
            continue

        distance_str = index_params.get("distance_metric", "L2")
        additional_params_dict = index_params.get("additional_params", None)
        distance = get_index_metric(distance_str.upper())
        if additional_params_dict:
            param_dict = normalize_additional_params(additional_params_dict)
            tensor.create_vdb_index(
                "hnsw_1", distance=distance, additional_params=param_dict
            )
        else:
            tensor.create_vdb_index("hnsw_1", distance=distance)

        return distance

    return None

0 comments on commit 1e08b0e

Please sign in to comment.