[WIP] Deeplake BM25 Implementation #2828

Draft · wants to merge 2 commits into base: main
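In short, this draft threads a new `bm25` flag through `index_params` (default `False` in `DEFAULT_VECTORSTORE_INDEX_PARAMS`), propagates `index_params` into derived dataset views, and teaches the index-maintenance path to create and maintain a BM25 index on the `text` tensor alongside the HNSW index on the `embedding` tensor. A minimal usage sketch — the path and sample data are placeholders, and it assumes the existing `VectorStore` constructor with its `index_params` argument:

```python
# Sketch only: assumes the existing VectorStore API; the BM25 index itself
# is built in libdeeplake once the flag reaches the index-maintenance code.
from deeplake.core.vectorstore import VectorStore

vs = VectorStore(
    path="mem://bm25_demo",  # placeholder path
    index_params={
        "threshold": -1,
        "bm25": True,  # new flag added by this PR (defaults to False)
        "distance_metric": "COS",
    },
)
vs.add(
    text=["the quick brown fox", "jumps over the lazy dog"],
    embedding=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
    metadata=[{}, {}],
)
```

Note that index creation remains gated by `index_used(self.exec_option)`, so the flag only takes effect for exec options that use indexes.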
1 change: 1 addition & 0 deletions deeplake/constants.py
@@ -329,6 +329,7 @@

DEFAULT_VECTORSTORE_INDEX_PARAMS = {
"threshold": -1,
"bm25": False,
"distance_metric": DEFAULT_VECTORSTORE_DISTANCE_METRIC,
"additional_params": {
"efConstruction": 600,
7 changes: 7 additions & 0 deletions deeplake/core/dataset/dataset.py
@@ -549,6 +549,7 @@ def __getitem__(
enabled_tensors=self.enabled_tensors,
view_base=self._view_base or self,
libdeeplake_dataset=self.libdeeplake_dataset,
index_params=self.index_params,
)
elif "/" in item:
splt = posixpath.split(item)
@@ -595,6 +596,7 @@ def __getitem__(
enabled_tensors=enabled_tensors,
view_base=self._view_base or self,
libdeeplake_dataset=self.libdeeplake_dataset,
index_params=self.index_params,
)
elif isinstance(item, tuple) and len(item) and isinstance(item[0], str):
ret = self
@@ -624,6 +626,7 @@ def __getitem__(
enabled_tensors=self.enabled_tensors,
view_base=self._view_base or self,
libdeeplake_dataset=self.libdeeplake_dataset,
index_params=self.index_params,
)
else:
raise InvalidKeyTypeError(item)
@@ -2904,6 +2907,7 @@ def parent(self):
path=self.path,
link_creds=self.link_creds,
libdeeplake_dataset=self.libdeeplake_dataset,
index_params=self.index_params,
)
self.storage.autoflush = autoflush
return ds
@@ -2927,6 +2931,7 @@ def root(self):
link_creds=self.link_creds,
view_base=self._view_base,
libdeeplake_dataset=self.libdeeplake_dataset,
index_params=self.index_params,
)
self.storage.autoflush = autoflush
return ds
@@ -2950,6 +2955,7 @@ def no_view_dataset(self):
pad_tensors=self._pad_tensors,
enabled_tensors=self.enabled_tensors,
libdeeplake_dataset=self.libdeeplake_dataset,
index_params=self.index_params,
)

def _create_group(self, name: str) -> "Dataset":
@@ -4824,6 +4830,7 @@ def max_view(self):
pad_tensors=True,
enabled_tensors=self.enabled_tensors,
libdeeplake_dataset=self.libdeeplake_dataset,
index_params=self.index_params,
)

def random_split(self, lengths: Sequence[Union[int, float]]):
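Every change in this file is the same one-liner: forward `index_params` when constructing the derived `Dataset` for views, parents, roots, and splits, so a view keeps the index configuration of its source. Without it, index maintenance on a view sees `index_params is None` and becomes a NOOP (see `index_operation_type_dataset` below). A toy illustration of the fix — not the real deeplake classes:

```python
# Toy model of the propagation bug: a hypothetical Dataset whose views are
# new Dataset objects. If index_params is not forwarded, the view loses it.
class Dataset:
    def __init__(self, rows, index_params=None):
        self.rows = rows
        self.index_params = index_params

    def __getitem__(self, sl):
        # Pre-PR behavior would be: return Dataset(self.rows[sl])
        return Dataset(self.rows[sl], index_params=self.index_params)

ds = Dataset(list(range(10)), index_params={"bm25": True})
view = ds[2:5]
assert view.index_params == {"bm25": True}  # configuration survives slicing
```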
122 changes: 102 additions & 20 deletions deeplake/core/index_maintenance.py
@@ -28,6 +28,16 @@ def is_embedding_tensor(tensor):
or tensor.key in valid_names
)

def is_text_tensor(tensor):
"""Check if a tensor is a text tensor."""

valid_names = ["text"]

return (
tensor.htype == "text"
or tensor.meta.name in valid_names
or tensor.key in valid_names
)

def validate_embedding_tensor(tensor):
"""Check if a tensor is an embedding tensor."""
@@ -40,6 +50,17 @@ def validate_embedding_tensor(tensor):
or tensor.key in valid_names
)

def validate_text_tensor(tensor):
"""Check if a tensor is an embedding tensor."""

valid_names = ["text"]

return (
tensor.meta.name in valid_names and
tensor.htype == "text" and
tensor.key in valid_names
)


def fetch_embedding_tensor(dataset):
tensors = dataset.tensors
@@ -48,8 +69,15 @@ def fetch_embedding_tensor(dataset):
return tensor
return None

def fetch_text_tensor(dataset):
tensors = dataset.tensors
for _, tensor in tensors.items():
if validate_text_tensor(tensor):
return tensor
return None


def index_exists(dataset):
def index_exists_for_embedding_tensor(dataset):
"""Check if the Index already exists."""
emb_tensor = fetch_embedding_tensor(dataset)
if emb_tensor is not None:
@@ -61,6 +89,18 @@ def index_exists(dataset):
else:
return False

def index_exists_for_text_tensor(dataset):
"""Check if the Index already exists."""
text_tensor = fetch_text_tensor(dataset)
if text_tensor is not None:
vdb_indexes = text_tensor.fetch_vdb_indexes()
if len(vdb_indexes) == 0:
return False
else:
return True
else:
return False


def index_used(exec_option):
"""Check if the index is used for the exec_option"""
@@ -110,7 +150,7 @@ def check_index_params(self):


def index_operation_type_dataset(self, num_rows, changed_data_len):
if not index_exists(self):
if not index_exists_for_embedding_tensor(self):
if self.index_params is None:
return INDEX_OP_TYPE.NOOP
threshold = self.index_params.get("threshold", -1)
@@ -183,13 +223,14 @@ def check_vdb_indexes(dataset):
def _incr_maintenance_vdb_indexes(tensor, indexes, index_operation):
try:
is_embedding = tensor.htype == "embedding"
is_text = tensor.htype == "text"
has_vdb_indexes = hasattr(tensor.meta, "vdb_indexes")
try:
vdb_index_ids_present = len(tensor.meta.vdb_indexes) > 0
except AttributeError:
vdb_index_ids_present = False

if is_embedding and has_vdb_indexes and vdb_index_ids_present:
if (is_embedding or is_text) and has_vdb_indexes and vdb_index_ids_present:
for vdb_index in tensor.meta.vdb_indexes:
tensor.update_vdb_index(
operation_kind=index_operation,
@@ -204,44 +245,71 @@ def index_operation_vectorstore(self):
if not index_used(self.exec_option):
return None

emb_tensor = fetch_embedding_tensor(self.dataset)

if index_exists(self.dataset) and check_index_params(self):
return emb_tensor.get_vdb_indexes()[0]["distance"]

threshold = self.index_params.get("threshold", -1)
below_threshold = threshold < 0 or len(self.dataset) < threshold
if below_threshold:
return None

if not check_index_params(self):
try:
vdb_indexes = emb_tensor.get_vdb_indexes()
for vdb_index in vdb_indexes:
emb_tensor.delete_vdb_index(vdb_index["id"])
except Exception as e:
raise Exception(f"An error occurred while removing VDB indexes: {e}")
bm25 = self.index_params.get("bm25", False)
print("BM25: ", bm25)
if bm25:
txt_tensor = fetch_text_tensor(self.dataset)

emb_tensor = fetch_embedding_tensor(self.dataset)

# TODO: revisit this return logic later.
if index_exists_for_embedding_tensor(self.dataset) and check_index_params(self):
return emb_tensor.get_vdb_indexes()[0]["distance"]

if bm25 and index_exists_for_text_tensor(self.dataset):
return txt_tensor.get_vdb_indexes()[0]

# if not check_index_params(self):
# try:
# vdb_indexes = tensor.get_vdb_indexes()
# for vdb_index in vdb_indexes:
# tensor.delete_vdb_index(vdb_index["id"])
# except Exception as e:
# raise Exception(f"An error occurred while removing VDB indexes: {e}")


if bm25:
print("Creating BM25 index")
txt_tensor.create_vdb_index("bm25")

distance_str = self.index_params.get("distance_metric", "COS")
additional_params_dict = self.index_params.get("additional_params", None)
distance = get_index_metric(distance_str.upper())
if additional_params_dict and len(additional_params_dict) > 0:
param_dict = normalize_additional_params(additional_params_dict)
print("Creating HNSW index")
emb_tensor.create_vdb_index(
"hnsw_1", distance=distance, additional_params=param_dict
)
else:
print("Creating HNSW index")
emb_tensor.create_vdb_index("hnsw_1", distance=distance)
return distance


def index_operation_dataset(self, dml_type, rowids):
if self.index_params is None:
return

bm25 = self.index_params.get("bm25", False)
txt_tensor = None
if bm25:
txt_tensor = fetch_text_tensor(self)

emb_tensor = fetch_embedding_tensor(self)
if emb_tensor is None:
if emb_tensor is None and txt_tensor is None:
return

num_rows = (
    txt_tensor.chunk_engine.num_samples
    if txt_tensor is not None
    else emb_tensor.chunk_engine.num_samples
)

index_operation_type = index_operation_type_dataset(
self,
emb_tensor.chunk_engine.num_samples,
num_rows,
len(rowids),
)

@@ -254,13 +322,23 @@ def index_operation_dataset(self, dml_type, rowids):
):
if index_operation_type == INDEX_OP_TYPE.REGENERATE_INDEX:
try:
vdb_indexes = emb_tensor.get_vdb_indexes()
for vdb_index in vdb_indexes:
emb_tensor.delete_vdb_index(vdb_index["id"])
if txt_tensor is not None:
print("Regenerating BM25 index for text tensor")
vdb_indexes = txt_tensor.get_vdb_indexes()
for vdb_index in vdb_indexes:
txt_tensor.delete_vdb_index(vdb_index["id"])
else:
vdb_indexes = emb_tensor.get_vdb_indexes()
for vdb_index in vdb_indexes:
emb_tensor.delete_vdb_index(vdb_index["id"])
except Exception as e:
raise Exception(
f"An error occurred while regenerating VDB indexes: {e}"
)
if txt_tensor is not None:
print("Creating BM25 index")
txt_tensor.create_vdb_index("bm25_1")

distance_str = self.index_params.get("distance_metric", "COS")
additional_params_dict = self.index_params.get("additional_params", None)
distance = get_index_metric(distance_str.upper())
@@ -272,6 +350,10 @@ def index_operation_dataset(self, dml_type, rowids):
else:
emb_tensor.create_vdb_index("hnsw_1", distance=distance)
elif index_operation_type == INDEX_OP_TYPE.INCREMENTAL_INDEX:
if txt_tensor is not None:
print("Incremental maintenance of BM25 index")
_incr_maintenance_vdb_indexes(txt_tensor, rowids, dml_type)

if emb_tensor is not None:
    _incr_maintenance_vdb_indexes(emb_tensor, rowids, dml_type)
else:
raise Exception("Unknown index operation")
3 changes: 3 additions & 0 deletions deeplake/core/meta/tensor_meta.py
@@ -229,6 +229,9 @@ def __setstate__(self, state: Dict[str, Any]):
if self.htype == "embedding" and not hasattr(self, "vdb_indexes"):
self.vdb_indexes = []
self._required_meta_keys += ("vdb_indexes",)
if self.htype == "text" and not hasattr(self, "vdb_indexes"):
self.vdb_indexes = []
self._required_meta_keys += ("vdb_indexes",)

@property
def nbytes(self):