Adding Rate limiter parameters to VectorStore add function (#2578)
adolkhan committed Sep 8, 2023
1 parent d487171 commit 48658fd
Showing 3 changed files with 97 additions and 29 deletions.
17 changes: 11 additions & 6 deletions deeplake/core/vectorstore/deeplake_vectorstore.py
@@ -14,6 +14,8 @@
import deeplake
from deeplake.constants import (
DEFAULT_VECTORSTORE_TENSORS,
MAX_BYTES_PER_MINUTE,
TARGET_BYTE_SIZE,
)
from deeplake.core.vectorstore import utils
from deeplake.core.vectorstore.vector_search import vector_search
@@ -157,7 +159,11 @@ def add(
embedding_data: Optional[Union[List, List[List]]] = None,
embedding_tensor: Optional[Union[str, List[str]]] = None,
return_ids: bool = False,
ingestion_batch_size: Optional[int] = None,
rate_limiter: Dict = {
"enabled": False,
"bytes_per_minute": MAX_BYTES_PER_MINUTE,
},
batch_byte_size: int = TARGET_BYTE_SIZE,
**tensors,
) -> Optional[List[str]]:
"""Adding elements to deeplake vector store.
@@ -226,7 +232,8 @@ def add(
embedding_data (Optional[List]): Data to be converted into embeddings using the provided ``embedding_function``. Defaults to None.
embedding_tensor (Optional[str]): Tensor where results from the embedding function will be stored. If None, the embedding tensor is automatically inferred (when possible). Defaults to None.
return_ids (bool): Whether to return added ids as an output of the method. Defaults to False.
ingestion_batch_size (int): Batch size to use for parallel ingestion. Defaults to 1000. Overrides the ``ingestion_batch_size`` specified when initializing the Vector Store.
rate_limiter (Dict): Rate limiter configuration. Defaults to ``{"enabled": False, "bytes_per_minute": MAX_BYTES_PER_MINUTE}``.
batch_byte_size (int): Approximate size, in bytes, of each batch sent to the embedding function. Defaults to ``TARGET_BYTE_SIZE``.
**tensors: Keyword arguments where the key is the tensor name, and the value is a list of samples that should be uploaded to that tensor.
Returns:
@@ -280,10 +287,8 @@ def add(
embedding_function=embedding_function,
embedding_data=embedding_data,
embedding_tensor=embedding_tensor,
ingestion_batch_size=ingestion_batch_size or self.ingestion_batch_size,
num_workers=0,
total_samples_processed=0,
logger=logger,
batch_byte_size=batch_byte_size,
rate_limiter=rate_limiter,
)

if self.verbose:
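For illustration, a call site might opt in to the new limiter roughly as follows. This is a minimal sketch, not code from the commit: vector_store, texts, and embedding_function are assumed placeholders, while the two constants come from deeplake.constants.

from deeplake.constants import MAX_BYTES_PER_MINUTE, TARGET_BYTE_SIZE

ids = vector_store.add(
    text=texts,
    embedding_function=embedding_function,
    embedding_data=texts,
    embedding_tensor="embedding",
    return_ids=True,
    rate_limiter={"enabled": True, "bytes_per_minute": MAX_BYTES_PER_MINUTE},
    batch_byte_size=TARGET_BYTE_SIZE,
)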
106 changes: 83 additions & 23 deletions deeplake/core/vectorstore/vector_search/dataset/dataset.py
@@ -395,6 +395,8 @@ def extend(
embedding_tensor: Union[str, List[str]],
processed_tensors: Dict[str, Union[List[Any], np.ndarray]],
dataset: deeplake.core.dataset.Dataset,
batch_byte_size: int,
rate_limiter: Dict,
):
"""
Function to extend the dataset with new data.
@@ -405,6 +407,8 @@ def extend(
embedding_tensor (Union[str, List[str]]): Name of the tensor(s) to store the embedding data.
processed_tensors (Dict[str, List[Any]]): Dictionary of tensors to be added to the dataset.
dataset (deeplake.core.dataset.Dataset): Dataset to be extended.
batch_byte_size (int): Approximate size, in bytes, of each batch sent to the embedding function.
rate_limiter (Dict): Rate limiter configuration.
Raises:
IncorrectEmbeddingShapeError: If the embedding function's output shape is incorrect.
@@ -415,27 +419,16 @@
for func, data, tensor in zip(
embedding_function, embedding_data, embedding_tensor
):
data_batched = chunk_by_bytes(data, target_byte_size=TARGET_BYTE_SIZE)

# Calculate the number of batches you can send each minute
batches_per_minute = MAX_BYTES_PER_MINUTE / TARGET_BYTE_SIZE

# Calculate sleep time in seconds between batches
sleep_time = 60 / batches_per_minute

data_iterator = data_iterator_factory(
data, func, batch_byte_size, rate_limiter
)
embedded_data = []

for data_i in tqdm(
data_batched, total=len(data_batched), desc="Creating embedding data"
for data in tqdm(
data_iterator, total=len(data_iterator), desc="creating embeddings"
):
start = time.time()
embedded_data.append(func(data_i))
end = time.time()
if func.__module__ == "langchain.embeddings.openai":
# we need to take into account the time spent on openai call
diff = sleep_time - (end - start)
if diff > 0:
time.sleep(diff)
embedded_data.append(data)

try:
return_embedded_data = np.vstack(embedded_data).astype(dtype=np.float32)
except ValueError:
@@ -446,7 +439,74 @@ def extend(

processed_tensors[tensor] = return_embedded_data

dataset.extend(processed_tensors)
dataset.extend(processed_tensors, progressbar=True)


class DataIterator:
    def __init__(self, data, func, batch_byte_size):
        self.data = chunk_by_bytes(data, batch_byte_size)
        self.data_itr = iter(self.data)
        self.index = 0
        self.func = func

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= len(self.data):
            raise StopIteration
        batch = next(self.data_itr)
        batch = self.func(batch)
        self.index += 1
        return batch

    def __len__(self):
        return len(self.data)
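chunk_by_bytes is called above but not shown in this diff. A rough sketch of what such a helper could look like, assuming it greedily packs items into batches of approximately the target byte size (an assumption, not the actual deeplake implementation):

def chunk_by_bytes(data, target_byte_size):
    # Hypothetical sketch: pack items greedily until a batch reaches
    # roughly target_byte_size bytes, then start a new batch.
    batches, current, current_size = [], [], 0
    for item in data:
        item_size = len(str(item).encode("utf-8"))  # crude size estimate
        if current and current_size + item_size > target_byte_size:
            batches.append(current)
            current, current_size = [], 0
        current.append(item)
        current_size += item_size
    if current:
        batches.append(current)
    return batches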


class RateLimitedDataIterator:
    def __init__(self, data, func, batch_byte_size, rate_limiter):
        self.data = chunk_by_bytes(data, batch_byte_size)
        self.data_iter = iter(self.data)
        self.index = 0
        self.rate_limiter = rate_limiter
        self.bytes_per_minute = rate_limiter["bytes_per_minute"]
        self.target_byte_size = batch_byte_size
        self.func = func

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= len(self.data):
            raise StopIteration
        batch = next(self.data_iter)
        self.index += 1
        # Calculate the number of batches that can be sent each minute
        batches_per_minute = self.bytes_per_minute / self.target_byte_size

        # Calculate the sleep time in seconds between batches
        sleep_time = 60 / batches_per_minute

        start = time.time()
        batch = self.func(batch)
        end = time.time()

        # Account for the time already spent in the embedding call
        diff = sleep_time - (end - start)
        if diff > 0:
            time.sleep(diff)
        return batch

    def __len__(self):
        return len(self.data)
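To make the throttling arithmetic concrete, here is a worked example with made-up numbers (the library defaults are MAX_BYTES_PER_MINUTE and TARGET_BYTE_SIZE, whose values are not shown in this diff):

bytes_per_minute = 300_000  # hypothetical limit
batch_byte_size = 10_000    # hypothetical batch size

batches_per_minute = bytes_per_minute / batch_byte_size  # 30.0
sleep_time = 60 / batches_per_minute                     # 2.0 s per batch

# If the embedding call itself took 0.5 s, __next__ sleeps the remaining
# 1.5 s, so at most 30 batches (about 300,000 bytes) are sent per minute.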


def data_iterator_factory(data, func, batch_byte_size, rate_limiter):
    if rate_limiter["enabled"]:
        return RateLimitedDataIterator(data, func, batch_byte_size, rate_limiter)
    else:
        return DataIterator(data, func, batch_byte_size)
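The factory mirrors how extend() consumes the iterators. A caller might wire it up like this (a sketch under the same placeholder assumptions as above; the np.vstack step matches what extend() does with the collected batches):

import numpy as np

iterator = data_iterator_factory(
    data=texts,
    func=embedding_function,
    batch_byte_size=TARGET_BYTE_SIZE,
    rate_limiter={"enabled": True, "bytes_per_minute": MAX_BYTES_PER_MINUTE},
)
embedded = [batch for batch in iterator]  # each batch is func(chunk)
embeddings = np.vstack(embedded).astype(np.float32)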


def extend_or_ingest_dataset(
@@ -455,10 +515,8 @@
embedding_function,
embedding_tensor,
embedding_data,
ingestion_batch_size,
num_workers,
total_samples_processed,
logger,
batch_byte_size,
rate_limiter,
):
# TODO: Add back the old logic with checkpointing after indexing is fixed
extend(
Expand All @@ -467,6 +525,8 @@ def extend_or_ingest_dataset(
embedding_tensor,
processed_tensors,
dataset,
batch_byte_size,
rate_limiter,
)


@@ -11,6 +11,7 @@
from deeplake.constants import (
DEFAULT_VECTORSTORE_DEEPLAKE_PATH,
DEFAULT_VECTORSTORE_TENSORS,
TARGET_BYTE_SIZE,
)
from deeplake.tests.common import requires_libdeeplake
from deeplake.constants import MAX_BYTES_PER_MINUTE
@@ -397,6 +398,8 @@ def mock_embedding_function(text):
embedding_tensor=["embedding"],
processed_tensors=processed_tensors,
dataset=dataset,
batch_byte_size=TARGET_BYTE_SIZE,
rate_limiter={"enabled": True, "bytes_per_minute": MAX_BYTES_PER_MINUTE},
)
end_time = time.time()

