def _to_sync_copy(self):
    """Return a synchronous ``Client`` configured like this async client.

    The copy is built on first use from this client's project, credentials,
    database, client info and client options, and is then cached on the
    instance as ``_sync_copy`` so later calls return the same object.

    Returns:
        :class:`~google.cloud.firestore_v1.client.Client`:
            The cached synchronous counterpart of this client.
    """
    # Imported at call time rather than at module import time.
    from google.cloud.firestore_v1.client import Client

    cached = getattr(self, "_sync_copy", None)
    if cached:
        return cached

    self._sync_copy = Client(
        project=self.project,
        credentials=self._credentials,
        database=self._database,
        client_info=self._client_info,
        client_options=self._client_options,
    )
    return self._sync_copy
This has the same set of methods for write operations that @@ -38,9 +38,16 @@ class BaseWriteBatch(object): def __init__(self, client) -> None: self._client = client self._write_pbs = [] + self._document_references: Dict[str, BaseDocumentReference] = {} self.write_results = None self.commit_time = None + def __len__(self): + return len(self._document_references) + + def __contains__(self, reference: BaseDocumentReference): + return reference._document_path in self._document_references + def _add_write_pbs(self, write_pbs: list) -> None: """Add `Write`` protobufs to this transaction. @@ -52,7 +59,13 @@ def _add_write_pbs(self, write_pbs: list) -> None: """ self._write_pbs.extend(write_pbs) - def create(self, reference: DocumentReference, document_data: dict) -> None: + @abc.abstractmethod + def commit(self): + """Sends all accumulated write operations to the server. The details of this + write depend on the implementing class.""" + raise NotImplementedError() + + def create(self, reference: BaseDocumentReference, document_data: dict) -> None: """Add a "change" to this batch to create a document. If the document given by ``reference`` already exists, then this @@ -65,11 +78,12 @@ def create(self, reference: DocumentReference, document_data: dict) -> None: creating a document. 
""" write_pbs = _helpers.pbs_for_create(reference._document_path, document_data) + self._document_references[reference._document_path] = reference self._add_write_pbs(write_pbs) def set( self, - reference: DocumentReference, + reference: BaseDocumentReference, document_data: dict, merge: Union[bool, list] = False, ) -> None: @@ -98,11 +112,12 @@ def set( reference._document_path, document_data ) + self._document_references[reference._document_path] = reference self._add_write_pbs(write_pbs) def update( self, - reference: DocumentReference, + reference: BaseDocumentReference, field_updates: dict, option: _helpers.WriteOption = None, ) -> None: @@ -126,10 +141,11 @@ def update( write_pbs = _helpers.pbs_for_update( reference._document_path, field_updates, option ) + self._document_references[reference._document_path] = reference self._add_write_pbs(write_pbs) def delete( - self, reference: DocumentReference, option: _helpers.WriteOption = None + self, reference: BaseDocumentReference, option: _helpers.WriteOption = None ) -> None: """Add a "change" to delete a document. @@ -146,9 +162,15 @@ def delete( state of the document before applying changes. """ write_pb = _helpers.pb_for_delete(reference._document_path, option) + self._document_references[reference._document_path] = reference self._add_write_pbs([write_pb]) - def _prep_commit(self, retry, timeout): + +class BaseWriteBatch(BaseBatch): + """Base class for a/sync implementations of the `commit` RPC. 
def bulk_writer(self, options: Optional[BulkWriterOptions] = None) -> BulkWriter:
    """Build a BulkWriter bound to this client.

    Args:
        options (Optional[:class:`~google.cloud.firestore_v1.bulk_writer.BulkWriterOptions`]):
            Optional control parameters, forwarded unchanged to the
            :class:`~google.cloud.firestore_v1.bulk_writer.BulkWriter`
            constructor.

    Returns:
        :class:`~google.cloud.firestore_v1.bulk_writer.BulkWriter`:
            A utility to efficiently create and save many `WriteBatch`
            instances to the server.
    """
    writer = BulkWriter(client=self, options=options)
    return writer
class BulkWriteBatch(BaseBatch):
    """Accumulate write operations and send them with the `batch_write` RPC.

    Use this over `WriteBatch` for higher volumes (e.g., via `BulkWriter`) and
    when the order of operations within a given batch is unimportant.

    Because the order in which individual write operations are applied to the
    database is not guaranteed, `batch_write` RPCs can never contain multiple
    operations to the same document. If calling code detects a second write
    operation to a known document reference, it should first cut off the
    previous batch and send it, then create a new batch starting with the
    latest write operation. In practice, the BulkWriter class handles this.

    This has the same set of methods for write operations that
    :class:`~google.cloud.firestore_v1.document.DocumentReference` does,
    e.g. :meth:`~google.cloud.firestore_v1.document.DocumentReference.create`.

    Args:
        client (:class:`~google.cloud.firestore_v1.client.Client`):
            The client that created this batch.
    """

    def __init__(self, client) -> None:
        super().__init__(client=client)

    def commit(
        self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None
    ) -> BatchWriteResponse:
        """Send the accumulated writes to the server via `batch_write`.

        Write operations are not guaranteed to be applied in order and must
        not contain multiple writes to any given document.

        After the call returns, the list of pending write protobufs is
        emptied and ``self.write_results`` holds the write results taken from
        the response.

        Args:
            retry (google.api_core.retry.Retry): Designation of what errors,
                if any, should be retried. Defaults to a system-specified
                policy.
            timeout (float): The timeout for this request. Defaults to a
                system-specified value.

        Returns:
            :class:`~google.cloud.firestore_v1.types.firestore.BatchWriteResponse`:
                The raw response returned by the `batch_write` call.
        """
        request, kwargs = self._prep_commit(retry, timeout)

        response: BatchWriteResponse = self._client._firestore_api.batch_write(
            request=request, metadata=self._client._rpc_metadata, **kwargs
        )

        # The accumulated writes have been sent; start over with a clean list.
        self._write_pbs = []
        self.write_results = list(response.write_results)
        return response

    def _prep_commit(self, retry: retries.Retry, timeout: float):
        """Build the request dict and retry/timeout kwargs used by `commit`."""
        request = dict(
            database=self._client._database_string,
            writes=self._write_pbs,
            labels=None,
        )
        return request, _helpers.make_retry_timeout_kwargs(retry, timeout)
class BulkRetry(enum.Enum):
    """Selects the delay strategy the BulkWriter applies between retries."""

    # Growing backoff between attempts. This strategy is largely incompatible
    # with the default retry limit of 15, so use with caution.
    exponential = 1

    # Default strategy: each retry waits one more second than the previous one.
    linear = 2

    # Retry right away, with no delay growth.
    immediate = 3
class AsyncBulkWriterMixin:
    """
    Mixin which contains the methods on `BulkWriter` which must only be
    submitted to the executor (or called by functions submitted to the executor).
    This mixin exists purely for organization and clarity of implementation
    (e.g., there is no metaclass magic).

    The entrypoint to the parallelizable code path is `_send_batch()`, which is
    wrapped in a decorator which ensures that the `SendMode` is honored.

    The host class is expected to provide `_send_mode`, `_executor`,
    `_options`, `_retries`, the three callbacks (`_batch_callback`,
    `_success_callback`, `_error_callback`) and the bookkeeping counters
    (`_in_flight_documents`, `_total_batches_sent`, `_total_write_operations`).
    """

    def _with_send_mode(fn):
        """Decorates a method to ensure it is only called via the executor
        (IFF the SendMode value is SendMode.parallel!).

        Usage:

            @_with_send_mode
            def my_method(self):
                parallel_stuff()

            def something_else(self):
                # Because of the decorator around `my_method`, the following
                # method invocation:
                self.my_method()
                # becomes equivalent to `self._executor.submit(self.my_method)`
                # when the send mode is `SendMode.parallel`.

        Use on entrypoint methods for code paths that *must* be parallelized.

        Returns:
            A wrapper that always returns a `concurrent.futures.Future`.
        """

        @functools.wraps(fn)
        def wrapper(self, *args, **kwargs):
            if self._send_mode == SendMode.parallel:
                return self._executor.submit(lambda: fn(self, *args, **kwargs))

            # For code parity, even `SendMode.serial` scenarios should return
            # a future here. Anything else would badly complicate calling code.
            # NOTE: in this branch an exception raised by `fn` propagates to
            # the caller directly instead of being stored on the future.
            result = fn(self, *args, **kwargs)
            future = concurrent.futures.Future()
            future.set_result(result)
            return future

        return wrapper

    @_with_send_mode
    def _send_batch(
        self, batch: BulkWriteBatch, operations: List["BulkWriterOperation"]
    ):
        """Sends a batch without regard to rate limits, meaning limits must have
        already been checked. To that end, do not call this directly; instead,
        call `_send_until_queue_is_empty`.

        Args:
            batch(:class:`~google.cloud.firestore_v1.base_batch.BulkWriteBatch`):
                The batch to send.
            operations(List[BulkWriterOperation]):
                The operations that were added to `batch`, passed on to
                `_process_response` so failures can be matched by index.
        """
        batch_length: int = len(batch)
        self._in_flight_documents += batch_length
        try:
            response: BatchWriteResponse = self._send(batch)
        finally:
            # Always release the in-flight slots, even when sending raises.
            # Otherwise a failed send would permanently inflate the in-flight
            # count that the rate-limiting check relies on.
            self._in_flight_documents -= batch_length

        # Update bookkeeping totals
        self._total_batches_sent += 1
        self._total_write_operations += batch_length

        self._process_response(batch, response, operations)

    def _process_response(
        self,
        batch: BulkWriteBatch,
        response: BatchWriteResponse,
        operations: List["BulkWriterOperation"],
    ) -> None:
        """Invokes submitted callbacks for each batch and each operation within
        each batch. As this is called from `_send_batch()`, this is parallelized
        if we are in that mode.

        The batch callback runs first. Then, for each entry of
        `response.status`, either the success callback runs (status code 0) or
        the error callback runs; when the error callback returns a truthy
        value the operation's attempt count is incremented and it is scheduled
        for retry.
        """
        batch_references: List[BaseDocumentReference] = list(
            batch._document_references.values(),
        )
        self._batch_callback(batch, response, self)

        status: status_pb2.Status
        for index, status in enumerate(response.status):
            if status.code == 0:
                self._success_callback(
                    # DocumentReference
                    batch_references[index],
                    # WriteResult
                    response.write_results[index],
                    # BulkWriter
                    self,
                )
                continue

            operation: BulkWriterOperation = operations[index]
            should_retry: bool = self._error_callback(
                # BulkWriteFailure
                BulkWriteFailure(
                    operation=operation, code=status.code, message=status.message,
                ),
                # BulkWriter
                self,
            )
            if should_retry:
                operation.attempts += 1
                self._retry_operation(operation)

    def _retry_operation(self, operation: "BulkWriterOperation") -> None:
        """Schedules `operation` for another attempt.

        The delay in seconds before the next attempt is `attempts ** 2` for
        `BulkRetry.exponential`, `attempts` for `BulkRetry.linear`, and zero
        otherwise.
        """
        delay: int = 0
        if self._options.retry == BulkRetry.exponential:
            delay = operation.attempts ** 2  # pragma: NO COVER
        elif self._options.retry == BulkRetry.linear:
            delay = operation.attempts

        run_at = datetime.datetime.utcnow() + datetime.timedelta(seconds=delay)

        # Use of `bisect.insort` maintains the requirement that `self._retries`
        # always remain sorted by each object's `run_at` time. Note that it is
        # able to do this because `OperationRetry` instances are entirely sortable
        # by their `run_at` value.
        bisect.insort(
            self._retries, OperationRetry(operation=operation, run_at=run_at),
        )

    def _send(self, batch: BulkWriteBatch) -> BatchWriteResponse:
        """Hook for overwriting the sending of batches. As this is only called
        from `_send_batch()`, this is parallelized if we are in that mode.
        """
        return batch.commit()  # pragma: NO COVER
def __init__(
    self,
    client: Optional["BaseClient"] = None,
    options: Optional["BulkWriterOptions"] = None,
):
    """Set up the writer's queues, callbacks, rate limiter and executor.

    Args:
        client: The client that created this BulkWriter. A client whose class
            is named ``AsyncClient`` is swapped for its synchronous copy.
        options: Optional control parameters; a default `BulkWriterOptions`
            is created when none (or a falsy value) is supplied.
    """
    # Because `BulkWriter` instances are all synchronous/blocking on the
    # main thread (instead using other threads for asynchrony), it is
    # incompatible with AsyncClient's various methods that return Futures.
    # `BulkWriter` parallelizes all of its network I/O without the developer
    # having to worry about awaiting async methods, so we must convert an
    # AsyncClient instance into a plain Client instance.
    if type(client).__name__ == "AsyncClient":
        self._client = client._to_sync_copy()
    else:
        self._client = client

    self._options = options or BulkWriterOptions()
    self._send_mode = self._options.mode

    # `_reset_operations` creates two parallel lists:
    #   * `self._operations`: the `BulkWriterOperation` objects of the
    #     in-progress batch, and
    #   * `self._operations_document_paths`: the `_document_path` string of
    #     each of those operations.
    # Both are reset together every time a batch is enqueued.
    self._reset_operations()

    # All `BulkWriterOperation` objects that are waiting to be retried, each
    # wrapped in an `OperationRetry` which pairs the raw operation with the
    # `datetime` of its next scheduled attempt. This must always remain
    # sorted for efficient reads, so only ever add to it via `bisect.insort`.
    self._retries = collections.deque()

    self._queued_batches = collections.deque()
    self._is_open = True

    # Futures returned from each submission to the executor, kept so that
    # `flush` can wait for all of them to complete.
    self._pending_batch_futures = []

    # Callbacks default to the no-op / default-policy static methods.
    self._success_callback = BulkWriter._default_on_success
    self._batch_callback = BulkWriter._default_on_batch
    self._error_callback = BulkWriter._default_on_error

    self._in_flight_documents = 0
    self._rate_limiter = RateLimiter(
        initial_tokens=self._options.initial_ops_per_second,
        global_max_tokens=self._options.max_ops_per_second,
    )

    # Keep track of progress as batches and write operations are completed
    self._total_batches_sent = 0
    self._total_write_operations = 0

    self._ensure_executor()
def flush(self):
    """
    Block until all pooled write operations are complete and then resume
    accepting new write operations.
    """
    # Calling `flush` consecutively is a no-op.
    if self._executor._shutdown:
        return

    while True:
        if self._operations:
            # Leftover operations that never filled a batch on their own
            # (any count not divisible by the batch size) are queued here,
            # then we take another lap.
            self._enqueue_current_batch()
        elif self._queued_batches or self._retries:
            # Queued-but-unsent batches or pending retries: kick off sending.
            # Retries may not be due yet under the backoff strategy, so
            # several laps through this branch can be needed.
            self._ensure_sending()

            # Without this pause the loop would spin at full speed, pinning
            # the calling thread near 100% CPU while the only remaining work
            # is waiting for retries to become ready.
            time.sleep(0.1)
        elif self._pending_batch_futures:
            # Take ownership of the futures collected so far and wait on
            # them. We loop again afterwards (rather than stopping) because
            # the batches just completed may have scheduled new retries.
            in_flight, self._pending_batch_futures = (
                self._pending_batch_futures,
                [],
            )
            concurrent.futures.wait(in_flight)
        else:
            break

    # We no longer expect to have any queued batches or pending futures,
    # so the executor can be shutdown.
    self._executor.shutdown()
This does not need to be + parallelized for two reasons: + + 1) Putting this on a worker thread could lead to two running in parallel + and thus unpredictable commit ordering or failure to adhere to + rate limits. + 2) This method only blocks when `self._request_send()` does not immediately + return, and in that case, the BulkWriter's ramp-up / throttling logic + has determined that it is attempting to exceed the maximum write speed, + and so parallelizing this method would not increase performance anyway. + + Once `self._request_send()` returns, this method calls `self._send_batch()`, + which parallelizes itself if that is our SendMode value. + + And once `self._send_batch()` is called (which does not block if we are + sending in parallel), jumps back to the top and re-checks for any queued + batches. + + Note that for sufficiently large data migrations, this can block the + submission of additional write operations (e.g., the CRUD methods); + but again, that is only if the maximum write speed is being exceeded, + and thus this scenario does not actually further reduce performance. + """ + self._schedule_ready_retries() + + while self._queued_batches: + + # For FIFO order, add to the right of this deque (via `append`) and take + # from the left (via `popleft`). + operations: List[BulkWriterOperation] = self._queued_batches.popleft() + + # Block until we are cleared for takeoff, which is fine because this + # returns instantly unless the rate limiting logic determines that we + # are attempting to exceed the maximum write speed. + self._request_send(len(operations)) + + # Handle some bookkeeping, and ultimately put these bits on the wire. + batch = BulkWriteBatch(client=self._client) + op: BulkWriterOperation + for op in operations: + op.add_to_batch(batch) + + # `_send_batch` is optionally parallelized by `@_with_send_mode`. 
def _request_send(self, batch_size: int) -> bool:
    """Block until a batch of `batch_size` writes is allowed to be sent.

    Two conditions must hold at the same time: the rate limiter has granted
    `batch_size` tokens, and the number of in-flight documents does not
    exceed the rate limiter's `_maximum_tokens`.

    Returns:
        bool: Always ``True``, once both conditions are satisfied.
    """
    # Remember a successful grant so tokens are not taken again while we are
    # only waiting on the in-flight limit.
    tokens_granted = False

    while True:
        # To avoid bottlenecks on the server, no more write operations may be
        # "in flight" (sent but still awaiting response) than the maximum
        # number of writes per second.
        within_in_flight_limit = (
            self._in_flight_documents <= self._rate_limiter._maximum_tokens
        )

        # Keep asking for tokens each lap until they are granted, then stop.
        if not tokens_granted:
            tokens_granted = self._rate_limiter.take_tokens(batch_size)

        if within_in_flight_limit and tokens_granted:
            return True

        # Brief pause so the main BulkWriter thread does not spin through
        # this loop burning CPU while waiting for a fixed moment in the
        # future.
        time.sleep(0.01)
def delete(
    self,
    reference: BaseDocumentReference,
    option: Optional[_helpers.WriteOption] = None,
    attempts: int = 0,
) -> None:
    """Adds a `delete` pb to the in-progress batch.

    If the in-progress batch already contains a write operation involving
    this document reference, the batch will be sealed and added to the commit
    queue, and a new batch will be created with this operation as its first
    entry.

    If this delete operation results in the in-progress batch reaching full
    capacity, then the batch will be similarly added to the commit queue, and
    a new batch will be created for future operations.

    Args:
        reference (:class:`~google.cloud.firestore_v1.base_document.BaseDocumentReference`):
            Pointer to the document that should be deleted.
        option (:class:`~google.cloud.firestore_v1._helpers.WriteOption`):
            Optional flag to modify the nature of this write.
        attempts (int):
            Attempt count stored on the queued operation; defaults to 0.
    """
    self._verify_not_closed()

    document_path = reference._document_path

    # A batch may hold at most one write per document, so seal the current
    # batch first when this document is already part of it.
    if document_path in self._operations_document_paths:
        self._enqueue_current_batch()

    operation = BulkWriterDeleteOperation(
        reference=reference, option=option, attempts=attempts,
    )
    self._operations.append(operation)
    self._operations_document_paths.append(document_path)

    self._maybe_enqueue_current_batch()
+ + If the in-progress batch already contains a write operation involving + this document reference, the batch will be sealed and added to the commit + queue, and a new batch will be created with this operation as its first + entry. + + If this update operation results in the in-progress batch reaching full + capacity, then the batch will be similarly added to the commit queue, and + a new batch will be created for future operations. + + Args: + reference (:class:`~google.cloud.firestore_v1.base_document.BaseDocumentReference`): + Pointer to the document that should be created. + field_updates (dict): + Key paths to specific nested data that should be upated. + option (:class:`~google.cloud.firestore_v1._helpers.WriteOption`): + Optional flag to modify the nature of this write. + """ + # This check is copied from other Firestore classes for the purposes of + # surfacing the error immediately. + if option.__class__.__name__ == "ExistsOption": + raise ValueError("you must not pass an explicit write option to update.") + + self._verify_not_closed() + + if reference._document_path in self._operations_document_paths: + self._enqueue_current_batch() + + self._operations.append( + BulkWriterUpdateOperation( + reference=reference, + field_updates=field_updates, + option=option, + attempts=attempts, + ) + ) + self._operations_document_paths.append(reference._document_path) + + self._maybe_enqueue_current_batch() + + def on_write_result( + self, + callback: Callable[[BaseDocumentReference, WriteResult, "BulkWriter"], None], + ) -> None: + """Sets a callback that will be invoked once for every successful operation.""" + self._success_callback = callback or BulkWriter._default_on_success + + def on_batch_result( + self, + callback: Callable[[BulkWriteBatch, BatchWriteResponse, "BulkWriter"], None], + ) -> None: + """Sets a callback that will be invoked once for every successful batch.""" + self._batch_callback = callback or BulkWriter._default_on_batch + + def on_write_error( 
class BulkWriterOperation:
    """Common base for the per-operation container classes.

    A `BulkWriterOperation` carries everything needed to perform one write,
    plus bookkeeping such as the number of attempts made so far. Keeping that
    state on the operation object lets a failed write be carried into a retry
    without being mixed up with other writes to the same document.
    """

    def add_to_batch(self, batch: BulkWriteBatch):
        """Adds `self` to the supplied batch.

        The concrete subclass decides which batch method is called and which
        of the operation's fields are forwarded as keyword arguments.

        Args:
            batch (BulkWriteBatch): The batch that should receive this write.

        Returns:
            Whatever the selected batch method returns.

        Raises:
            TypeError: If `self` is not one of the known operation subclasses.
        """
        assert isinstance(batch, BulkWriteBatch)

        if isinstance(self, BulkWriterCreateOperation):
            add_write = batch.create
            arguments = {
                "reference": self.reference,
                "document_data": self.document_data,
            }
        elif isinstance(self, BulkWriterDeleteOperation):
            add_write = batch.delete
            arguments = {"reference": self.reference, "option": self.option}
        elif isinstance(self, BulkWriterSetOperation):
            add_write = batch.set
            arguments = {
                "reference": self.reference,
                "document_data": self.document_data,
                "merge": self.merge,
            }
        elif isinstance(self, BulkWriterUpdateOperation):
            add_write = batch.update
            arguments = {
                "reference": self.reference,
                "field_updates": self.field_updates,
                "option": self.option,
            }
        else:
            raise TypeError(
                f"Unexpected type of {self.__class__.__name__} for batch"
            )  # pragma: NO COVER

        return add_write(**arguments)
+ """ + + def __lt__(self, other: "OperationRetry"): + """Allows use of `bisect` to maintain a sorted list of `OperationRetry` + instances, which in turn allows us to cheaply grab all that are ready to + run.""" + if isinstance(other, OperationRetry): + return self.run_at < other.run_at + elif isinstance(other, datetime.datetime): + return self.run_at < other + return NotImplemented # pragma: NO COVER + + def retry(self, bulk_writer: BulkWriter) -> None: + """Call this after waiting any necessary time to re-add the enclosed + operation to the supplied BulkWriter's internal queue.""" + if isinstance(self.operation, BulkWriterCreateOperation): + bulk_writer.create( + reference=self.operation.reference, + document_data=self.operation.document_data, + attempts=self.operation.attempts, + ) + + elif isinstance(self.operation, BulkWriterDeleteOperation): + bulk_writer.delete( + reference=self.operation.reference, + option=self.operation.option, + attempts=self.operation.attempts, + ) + + elif isinstance(self.operation, BulkWriterSetOperation): + bulk_writer.set( + reference=self.operation.reference, + document_data=self.operation.document_data, + merge=self.operation.merge, + attempts=self.operation.attempts, + ) + + elif isinstance(self.operation, BulkWriterUpdateOperation): + bulk_writer.update( + reference=self.operation.reference, + field_updates=self.operation.field_updates, + option=self.operation.option, + attempts=self.operation.attempts, + ) + else: + raise TypeError( + f"Unexpected type of {self.operation.__class__.__name__} for OperationRetry.retry" + ) # pragma: NO COVER + + +try: + from dataclasses import dataclass + + @dataclass + class BulkWriterOptions: + initial_ops_per_second: int = 500 + max_ops_per_second: int = 500 + mode: SendMode = SendMode.parallel + retry: BulkRetry = BulkRetry.linear + + @dataclass + class BulkWriteFailure: + operation: BulkWriterOperation + # https://grpc.github.io/grpc/core/md_doc_statuscodes.html + code: int + message: str + + 
    @dataclass
    class OperationRetry(BaseOperationRetry):
        """Container for an additional attempt at an operation, scheduled for
        the future.

        Ordering (`__lt__`, comparing `run_at`) and the `retry` method are
        inherited from `BaseOperationRetry`.
        """

        # The write operation that should be attempted again.
        operation: BulkWriterOperation
        # Timestamp used to order instances; presumably the earliest moment at
        # which the retry should run (TODO confirm against the scheduling code).
        run_at: datetime.datetime
+ + class BulkWriterOptions: + def __init__( + self, + initial_ops_per_second: int = 500, + max_ops_per_second: int = 500, + mode: SendMode = SendMode.parallel, + retry: BulkRetry = BulkRetry.linear, + ): + self.initial_ops_per_second = initial_ops_per_second + self.max_ops_per_second = max_ops_per_second + self.mode = mode + self.retry = retry + + class BulkWriteFailure: + def __init__( + self, + operation: BulkWriterOperation, + # https://grpc.github.io/grpc/core/md_doc_statuscodes.html + code: int, + message: str, + ): + self.operation = operation + self.code = code + self.message = message + + @property + def attempts(self) -> int: + return self.operation.attempts + + class OperationRetry(BaseOperationRetry): + """Container for an additional attempt at an operation, scheduled for + the future.""" + + def __init__( + self, operation: BulkWriterOperation, run_at: datetime.datetime, + ): + self.operation = operation + self.run_at = run_at + + class BulkWriterCreateOperation(BulkWriterOperation): + """Container for BulkWriter.create() operations.""" + + def __init__( + self, + reference: BaseDocumentReference, + document_data: Dict, + attempts: int = 0, + ): + self.reference = reference + self.document_data = document_data + self.attempts = attempts + + class BulkWriterUpdateOperation(BulkWriterOperation): + """Container for BulkWriter.update() operations.""" + + def __init__( + self, + reference: BaseDocumentReference, + field_updates: Dict, + option: Optional[_helpers.WriteOption], + attempts: int = 0, + ): + self.reference = reference + self.field_updates = field_updates + self.option = option + self.attempts = attempts + + class BulkWriterSetOperation(BulkWriterOperation): + """Container for BulkWriter.set() operations.""" + + def __init__( + self, + reference: BaseDocumentReference, + document_data: Dict, + merge: Union[bool, list] = False, + attempts: int = 0, + ): + self.reference = reference + self.document_data = document_data + self.merge = merge + 
def utcnow():
    """Return the current UTC time as a naive ``datetime``.

    ``RateLimiter`` reads the clock only through this function.
    """
    return datetime.datetime.utcnow()


default_initial_tokens: int = 500
default_phase_length: int = 60 * 5  # 5 minutes
# Retained for backward compatibility with code that imports it.
microseconds_per_second: int = 1000000


class RateLimiter:
    """Implements 5/5/5 ramp-up via Token Bucket algorithm.

    5/5/5 is a ramp up strategy that starts with a budget of 500 operations per
    second. Additionally, every 5 minutes, the maximum budget can increase by
    50%. Thus, at 5:01 into a long bulk-writing process, the maximum budget
    becomes 750 operations per second. At 10:01, the budget becomes 1,125
    operations per second.

    The Token Bucket algorithm models a container of tokens from which a
    caller draws. If tokens are available the caller may proceed; if not, it
    may not. Tokens replenish over time, up to the current maximum.

    Usage:

        rate_limiter = RateLimiter()
        tokens = rate_limiter.take_tokens(20)

        if not tokens:
            queue_retry()
        else:
            for _ in range(tokens):
                my_operation()

    Args:
        initial_tokens (int): Starting size of the budget. Defaults to 500.
        global_max_tokens (Optional[int]): If provided, a hard cap applied to
            both the initial budget and every later 50% increase.
        phase_length (int): Number of seconds after which the size of the
            budget can increase by 50%. Defaults to 300 (5 minutes).
    """

    def __init__(
        self,
        initial_tokens: int = default_initial_tokens,
        global_max_tokens: Optional[int] = None,
        phase_length: int = default_phase_length,
    ):
        # Tracks the volume of operations during a given ramp-up phase.
        self._operations_this_phase: int = 0

        # If provided, this enforces a cap on the maximum number of writes per
        # second we can ever attempt, regardless of how many 50% increases the
        # 5/5/5 rule would grant.
        self._global_max_tokens = global_max_tokens

        # Both are set lazily by `_start_clock` on the first `take_tokens`.
        self._start: Optional[datetime.datetime] = None
        self._last_refill: Optional[datetime.datetime] = None

        # Current number of available operations. Decrements with every
        # permitted request and refills over time.
        self._available_tokens: int = initial_tokens

        # Maximum size of the available operations. Can increase by 50%
        # every [phase_length] number of seconds.
        self._maximum_tokens: int = self._available_tokens

        if self._global_max_tokens is not None:
            self._available_tokens = min(
                self._available_tokens, self._global_max_tokens
            )
            self._maximum_tokens = min(self._maximum_tokens, self._global_max_tokens)

        # Number of seconds after which the [_maximum_tokens] can increase by 50%.
        self._phase_length: int = phase_length

        # Index of the current [_phase_length]-second window since the start.
        self._phase: int = 0

    def _start_clock(self) -> None:
        """Record the start and last-refill times if they are not set yet."""
        self._start = self._start or utcnow()
        self._last_refill = self._last_refill or utcnow()

    def take_tokens(self, num: int = 1, allow_less: bool = False) -> int:
        """Returns the number of available tokens, up to the amount requested.

        Args:
            num (int): Number of tokens wanted.
            allow_less (bool): When true, a partial grant (at least one token)
                is acceptable; otherwise it is all of ``num`` or nothing.

        Returns:
            int: The number of tokens granted, or 0 if the request could not be
            satisfied.
        """
        self._start_clock()
        self._check_phase()
        self._refill()

        minimum_tokens = 1 if allow_less else num
        if self._available_tokens < minimum_tokens:
            return 0

        granted = min(self._available_tokens, num)
        self._available_tokens -= granted
        self._operations_this_phase += granted
        return granted

    def _check_phase(self) -> None:
        """Advances [_phase] and, if warranted, raises the maximum budget.

        The expected phase is the number of whole [_phase_length] windows
        elapsed since the clock started. If it differs from the stored phase,
        the per-phase operation counter is reset and the stored phase is
        updated. The maximum budget grows by 50% (once per call) when the phase
        moved forward and at least one token was taken during the phase that
        just ended. It is never decreased.

        This is a no-op while the stored phase is still the expected one.
        """
        age: datetime.timedelta = utcnow() - self._start

        # `total_seconds()` is required here: `timedelta.seconds` is only the
        # 0..86399 component and would wrap back to phase 0 after 24 hours.
        # Clamp at zero in case the wall clock moved backwards.
        expected_phase: int = max(0, int(age.total_seconds()) // self._phase_length)

        # Short-circuit if we are still in the expected phase.
        if expected_phase == self._phase:
            return

        operations_last_phase: int = self._operations_this_phase
        self._operations_this_phase = 0

        previous_phase: int = self._phase
        self._phase = expected_phase

        # No-op if we did nothing for an entire phase.
        if operations_last_phase and self._phase > previous_phase:
            self._increase_maximum_tokens()

    def _increase_maximum_tokens(self) -> None:
        """Grow the maximum budget by 50%, respecting the global cap if set."""
        self._maximum_tokens = round(self._maximum_tokens * 1.5)
        if self._global_max_tokens is not None:
            self._maximum_tokens = min(self._maximum_tokens, self._global_max_tokens)

    def _refill(self) -> None:
        """Replenishes any tokens that should have regenerated since the last
        refill."""
        now: datetime.datetime = utcnow()
        elapsed_seconds: float = (now - self._last_refill).total_seconds()

        if elapsed_seconds < 0:
            # The wall clock moved backwards: resynchronise and grant nothing.
            self._last_refill = now
            return

        # A full second (or more, including gaps longer than a day) always
        # restores the bucket to maximum capacity.
        if elapsed_seconds >= 1:
            self._available_tokens = self._maximum_tokens
            self._last_refill = now
            return

        # Less than a second: provision a proportional share of the maximum.
        new_tokens: int = round(elapsed_seconds * self._maximum_tokens)
        if not new_tokens:
            # Leave `_last_refill` untouched so the elapsed time keeps
            # accumulating; otherwise rapid callers would never earn a token.
            return

        # Add the number of provisioned tokens, capped at the maximum size.
        self._available_tokens = min(
            self._maximum_tokens, self._available_tokens + new_tokens,
        )
        self._last_refill = now
class FakeThreadPoolExecutor:
    """Synchronous test double for `concurrent.futures.ThreadPoolExecutor`.

    Submitted callables run immediately on the calling thread, and the
    returned future is already resolved by the time `submit` returns.
    """

    def __init__(self, *args, **kwargs):
        # Constructor arguments (e.g. `max_workers`) are accepted and ignored.
        self._shutdown = False

    def submit(self, callable, *args, **kwargs) -> concurrent.futures.Future:
        """Run `callable(*args, **kwargs)` now and return a completed future.

        An exception raised by the callable propagates directly to the caller
        of `submit`; it is not stored on the future.

        Raises:
            RuntimeError: If the executor has already been shut down.
        """
        if self._shutdown:
            raise RuntimeError(
                "cannot schedule new futures after shutdown"
            )  # pragma: NO COVER
        future = concurrent.futures.Future()
        future.set_result(callable(*args, **kwargs))
        return future

    def shutdown(self, wait=True, *, cancel_futures=False):
        """Mark the executor as shut down.

        `wait` and `cancel_futures` are accepted for signature compatibility
        with the real executor and are ignored: nothing is ever pending.
        """
        self._shutdown = True

    def __repr__(self):
        return f"FakeThreadPoolExecutor(shutdown={self._shutdown})"
b766c22fc..bb7a51dd8 100644 --- a/tests/unit/v1/test_async_client.py +++ b/tests/unit/v1/test_async_client.py @@ -373,6 +373,21 @@ async def test_get_all_unknown_result(self): metadata=client._rpc_metadata, ) + def test_bulk_writer(self): + """BulkWriter is opaquely async and thus does not have a dedicated + async variant.""" + from google.cloud.firestore_v1.bulk_writer import BulkWriter + + client = self._make_default_one() + bulk_writer = client.bulk_writer() + self.assertIsInstance(bulk_writer, BulkWriter) + self.assertIs(bulk_writer._client, client._sync_copy) + + def test_sync_copy(self): + client = self._make_default_one() + # Multiple calls to this method should return the same cached instance. + self.assertIs(client._to_sync_copy(), client._to_sync_copy()) + def test_batch(self): from google.cloud.firestore_v1.async_batch import AsyncWriteBatch diff --git a/tests/unit/v1/test_base_batch.py b/tests/unit/v1/test_base_batch.py index affe0e139..6bdb0da07 100644 --- a/tests/unit/v1/test_base_batch.py +++ b/tests/unit/v1/test_base_batch.py @@ -13,16 +13,26 @@ # limitations under the License. 
class TestableBaseWriteBatch(BaseWriteBatch):
    """Create a fake subclass of `BaseWriteBatch` for the purposes of
    evaluating the shared methods."""

    def __init__(self, client):
        super().__init__(client=client)

    def commit(self):
        # Defined only so the subclass has a `commit`; the tests of the shared
        # methods never call it.
        pass  # pragma: NO COVER
class TestBulkWriteBatch(unittest.TestCase):
    """Tests the BulkWriteBatch.commit method"""

    @staticmethod
    def _get_target_class():
        from google.cloud.firestore_v1.bulk_batch import BulkWriteBatch

        return BulkWriteBatch

    def _make_one(self, *args, **kwargs):
        return self._get_target_class()(*args, **kwargs)

    def test_constructor(self):
        """A new batch keeps its client and starts with no writes or results."""
        batch = self._make_one(mock.sentinel.client)
        self.assertIs(batch._client, mock.sentinel.client)
        self.assertEqual(batch._write_pbs, [])
        self.assertIsNone(batch.write_results)

    def _write_helper(self, retry=None, timeout=None):
        """Commit a two-write batch against a fake GAPIC and check the call."""
        from google.cloud.firestore_v1 import _helpers
        from google.cloud.firestore_v1.types import firestore
        from google.cloud.firestore_v1.types import write

        # Minimal fake GAPIC returning a canned two-result response.
        fake_api = mock.Mock(spec=["batch_write"])
        expected_response = firestore.BatchWriteResponse(
            write_results=[write.WriteResult(), write.WriteResult()],
        )
        fake_api.batch_write.return_value = expected_response
        call_kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)

        # A real client, wired to the fake GAPIC.
        client = _make_client("grand")
        client._firestore_api_internal = fake_api

        # Queue one create and one delete, checking membership along the way.
        batch = self._make_one(client)
        created_document = client.document("a", "b")
        self.assertNotIn(created_document, batch)
        batch.create(created_document, {"ten": 10, "buck": "ets"})
        self.assertIn(created_document, batch)
        deleted_document = client.document("c", "d", "e", "f")
        batch.delete(deleted_document)
        # Snapshot the writes now; the batch is expected to be empty afterwards.
        pending_writes = list(batch._write_pbs)

        response = batch.commit(**call_kwargs)
        self.assertEqual(
            response.write_results, list(expected_response.write_results)
        )
        self.assertEqual(batch.write_results, response.write_results)
        # Make sure batch has no more "changes".
        self.assertEqual(batch._write_pbs, [])

        # Verify the mocks.
        expected_request = {
            "database": client._database_string,
            "writes": pending_writes,
            "labels": None,
        }
        fake_api.batch_write.assert_called_once_with(
            request=expected_request,
            metadata=client._rpc_metadata,
            **call_kwargs,
        )

    def test_write(self):
        self._write_helper()

    def test_write_w_retry_timeout(self):
        from google.api_core.retry import Retry

        self._write_helper(retry=Retry(predicate=object()), timeout=123.0)
class NoSendBulkWriter(BulkWriter):
    """Test-friendly BulkWriter subclass whose `_send` method returns faked
    BatchWriteResponse instances and whose `_process_response` method stores
    those faked instances for later evaluation."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One (batch, response, operations) entry per processed batch.
        self._responses: List[
            Tuple[BulkWriteBatch, BatchWriteResponse, List[BulkWriterOperation]]
        ] = []
        # Positions within each batch whose writes should be reported as failed.
        self._fail_indices: List[int] = []

    def _send(self, batch: BulkWriteBatch) -> BatchWriteResponse:
        """Generate a fake `BatchWriteResponse` for the supplied batch instead
        of actually submitting it to the server.

        One write result and one status are produced per document reference in
        the batch. Positions listed in `_fail_indices` get an empty
        `WriteResult` and status code 1; all others get a timestamped
        `WriteResult` and status code 0.
        """
        positions = range(len(batch._document_references))
        failing = self._fail_indices
        return BatchWriteResponse(
            write_results=[
                WriteResult() if index in failing
                else WriteResult(update_time=build_timestamp())
                for index in positions
            ],
            status=[
                status_pb2.Status(code=1 if index in failing else 0)
                for index in positions
            ],
        )

    def _process_response(
        self,
        batch: BulkWriteBatch,
        response: BatchWriteResponse,
        operations: List[BulkWriterOperation],
    ) -> None:
        """Run the normal response handling, then record the call's arguments."""
        super()._process_response(batch, response, operations)
        self._responses.append((batch, response, operations))

    def _instantiate_executor(self):
        """Return the synchronous fake executor used by these tests."""
        return FakeThreadPoolExecutor()
+ counts: (tuple) A sequence of integer pairs, with 0-index integers + representing the size of sent batches, and 1-index integers + representing the number of times batches of that size should + have been sent. + """ + total_batches = sum([el[1] for el in counts]) + batches_word = "batches" if total_batches != 1 else "batch" + self.assertEqual( + len(bw._responses), + total_batches, + f"Expected to have sent {total_batches} {batches_word}, but only sent {len(bw._responses)}", + ) + docs_count = {} + resp: BatchWriteResponse + for _, resp, ops in bw._responses: + docs_count.setdefault(len(resp.write_results), 0) + docs_count[len(resp.write_results)] += 1 + + self.assertEqual(len(docs_count), len(counts)) + for size, num_sent in counts: + self.assertEqual(docs_count[size], num_sent) + + # Assert flush leaves no operation behind + self.assertEqual(len(bw._operations), 0) + + def test_create_calls_send_correctly(self): + bw = NoSendBulkWriter(self.client) + for ref, data in self._doc_iter(101): + bw.create(ref, data) + bw.flush() + # Full batches with 20 items should have been sent 5 times, and a 1-item + # batch should have been sent once. + self._verify_bw_activity(bw, [(20, 5,), (1, 1,)]) + + def test_delete_calls_send_correctly(self): + bw = NoSendBulkWriter(self.client) + for ref, _ in self._doc_iter(101): + bw.delete(ref) + bw.flush() + # Full batches with 20 items should have been sent 5 times, and a 1-item + # batch should have been sent once. 
+ self._verify_bw_activity(bw, [(20, 5,), (1, 1,)]) + + def test_delete_separates_batch(self): + bw = NoSendBulkWriter(self.client) + ref = self._get_document_reference(id="asdf") + bw.create(ref, {}) + bw.delete(ref) + bw.flush() + # Consecutive batches each with 1 operation should have been sent + self._verify_bw_activity(bw, [(1, 2,)]) + + def test_set_calls_send_correctly(self): + bw = NoSendBulkWriter(self.client) + for ref, data in self._doc_iter(101): + bw.set(ref, data) + bw.flush() + # Full batches with 20 items should have been sent 5 times, and a 1-item + # batch should have been sent once. + self._verify_bw_activity(bw, [(20, 5,), (1, 1,)]) + + def test_update_calls_send_correctly(self): + bw = NoSendBulkWriter(self.client) + for ref, data in self._doc_iter(101): + bw.update(ref, data) + bw.flush() + # Full batches with 20 items should have been sent 5 times, and a 1-item + # batch should have been sent once. + self._verify_bw_activity(bw, [(20, 5,), (1, 1,)]) + + def test_update_separates_batch(self): + bw = NoSendBulkWriter(self.client) + ref = self._get_document_reference(id="asdf") + bw.create(ref, {}) + bw.update(ref, {"field": "value"}) + bw.flush() + # Full batches with 20 items should have been sent 5 times, and a 1-item + # batch should have been sent once. 
+ self._verify_bw_activity(bw, [(1, 2,)]) + + def test_invokes_success_callbacks_successfully(self): + bw = NoSendBulkWriter(self.client) + bw._fail_indices = [] + bw._sent_batches = 0 + bw._sent_documents = 0 + + def _on_batch(batch, response, bulk_writer): + assert isinstance(batch, BulkWriteBatch) + assert isinstance(response, BatchWriteResponse) + assert isinstance(bulk_writer, BulkWriter) + bulk_writer._sent_batches += 1 + + def _on_write(ref, result, bulk_writer): + assert isinstance(ref, BaseDocumentReference) + assert isinstance(result, WriteResult) + assert isinstance(bulk_writer, BulkWriter) + bulk_writer._sent_documents += 1 + + bw.on_write_result(_on_write) + bw.on_batch_result(_on_batch) + + for ref, data in self._doc_iter(101): + bw.create(ref, data) + bw.flush() + + self.assertEqual(bw._sent_batches, 6) + self.assertEqual(bw._sent_documents, 101) + self.assertEqual(len(bw._operations), 0) + + def test_invokes_error_callbacks_successfully(self): + bw = NoSendBulkWriter(self.client) + # First document in each batch will "fail" + bw._fail_indices = [0] + bw._sent_batches = 0 + bw._sent_documents = 0 + bw._total_retries = 0 + + times_to_retry = 1 + + def _on_batch(batch, response, bulk_writer): + bulk_writer._sent_batches += 1 + + def _on_write(ref, result, bulk_writer): + bulk_writer._sent_documents += 1 # pragma: NO COVER + + def _on_error(error, bw) -> bool: + assert isinstance(error, BulkWriteFailure) + should_retry = error.attempts < times_to_retry + if should_retry: + bw._total_retries += 1 + return should_retry + + bw.on_batch_result(_on_batch) + bw.on_write_result(_on_write) + bw.on_write_error(_on_error) + + for ref, data in self._doc_iter(1): + bw.create(ref, data) + bw.flush() + + self.assertEqual(bw._sent_documents, 0) + self.assertEqual(bw._total_retries, times_to_retry) + self.assertEqual(bw._sent_batches, 2) + self.assertEqual(len(bw._operations), 0) + + def test_invokes_error_callbacks_successfully_multiple_retries(self): + bw = 
NoSendBulkWriter( + self.client, options=BulkWriterOptions(retry=BulkRetry.immediate), + ) + # First document in each batch will "fail" + bw._fail_indices = [0] + bw._sent_batches = 0 + bw._sent_documents = 0 + bw._total_retries = 0 + + times_to_retry = 10 + + def _on_batch(batch, response, bulk_writer): + bulk_writer._sent_batches += 1 + + def _on_write(ref, result, bulk_writer): + bulk_writer._sent_documents += 1 + + def _on_error(error, bw) -> bool: + assert isinstance(error, BulkWriteFailure) + should_retry = error.attempts < times_to_retry + if should_retry: + bw._total_retries += 1 + return should_retry + + bw.on_batch_result(_on_batch) + bw.on_write_result(_on_write) + bw.on_write_error(_on_error) + + for ref, data in self._doc_iter(2): + bw.create(ref, data) + bw.flush() + + self.assertEqual(bw._sent_documents, 1) + self.assertEqual(bw._total_retries, times_to_retry) + self.assertEqual(bw._sent_batches, times_to_retry + 1) + self.assertEqual(len(bw._operations), 0) + + def test_default_error_handler(self): + bw = NoSendBulkWriter( + self.client, options=BulkWriterOptions(retry=BulkRetry.immediate), + ) + bw._attempts = 0 + + def _on_error(error, bw): + bw._attempts = error.attempts + return bw._default_on_error(error, bw) + + bw.on_write_error(_on_error) + + # First document in each batch will "fail" + bw._fail_indices = [0] + for ref, data in self._doc_iter(1): + bw.create(ref, data) + bw.flush() + self.assertEqual(bw._attempts, 15) + + def test_handles_errors_and_successes_correctly(self): + bw = NoSendBulkWriter( + self.client, options=BulkWriterOptions(retry=BulkRetry.immediate), + ) + # First document in each batch will "fail" + bw._fail_indices = [0] + bw._sent_batches = 0 + bw._sent_documents = 0 + bw._total_retries = 0 + + times_to_retry = 1 + + def _on_batch(batch, response, bulk_writer): + bulk_writer._sent_batches += 1 + + def _on_write(ref, result, bulk_writer): + bulk_writer._sent_documents += 1 + + def _on_error(error, bw) -> bool: + assert 
isinstance(error, BulkWriteFailure) + should_retry = error.attempts < times_to_retry + if should_retry: + bw._total_retries += 1 + return should_retry + + bw.on_batch_result(_on_batch) + bw.on_write_result(_on_write) + bw.on_write_error(_on_error) + + for ref, data in self._doc_iter(40): + bw.create(ref, data) + bw.flush() + + # 19 successful writes per batch + self.assertEqual(bw._sent_documents, 38) + self.assertEqual(bw._total_retries, times_to_retry * 2) + self.assertEqual(bw._sent_batches, 4) + self.assertEqual(len(bw._operations), 0) + + def test_create_retriable(self): + bw = NoSendBulkWriter( + self.client, options=BulkWriterOptions(retry=BulkRetry.immediate), + ) + # First document in each batch will "fail" + bw._fail_indices = [0] + bw._total_retries = 0 + times_to_retry = 6 + + def _on_error(error, bw) -> bool: + assert isinstance(error, BulkWriteFailure) + should_retry = error.attempts < times_to_retry + if should_retry: + bw._total_retries += 1 + return should_retry + + bw.on_write_error(_on_error) + + for ref, data in self._doc_iter(1): + bw.create(ref, data) + bw.flush() + + self.assertEqual(bw._total_retries, times_to_retry) + self.assertEqual(len(bw._operations), 0) + + def test_delete_retriable(self): + bw = NoSendBulkWriter( + self.client, options=BulkWriterOptions(retry=BulkRetry.immediate), + ) + # First document in each batch will "fail" + bw._fail_indices = [0] + bw._total_retries = 0 + times_to_retry = 6 + + def _on_error(error, bw) -> bool: + assert isinstance(error, BulkWriteFailure) + should_retry = error.attempts < times_to_retry + if should_retry: + bw._total_retries += 1 + return should_retry + + bw.on_write_error(_on_error) + + for ref, _ in self._doc_iter(1): + bw.delete(ref) + bw.flush() + + self.assertEqual(bw._total_retries, times_to_retry) + self.assertEqual(len(bw._operations), 0) + + def test_set_retriable(self): + bw = NoSendBulkWriter( + self.client, options=BulkWriterOptions(retry=BulkRetry.immediate), + ) + # First document 
in each batch will "fail" + bw._fail_indices = [0] + bw._total_retries = 0 + times_to_retry = 6 + + def _on_error(error, bw) -> bool: + assert isinstance(error, BulkWriteFailure) + should_retry = error.attempts < times_to_retry + if should_retry: + bw._total_retries += 1 + return should_retry + + bw.on_write_error(_on_error) + + for ref, data in self._doc_iter(1): + bw.set(ref, data) + bw.flush() + + self.assertEqual(bw._total_retries, times_to_retry) + self.assertEqual(len(bw._operations), 0) + + def test_update_retriable(self): + bw = NoSendBulkWriter( + self.client, options=BulkWriterOptions(retry=BulkRetry.immediate), + ) + # First document in each batch will "fail" + bw._fail_indices = [0] + bw._total_retries = 0 + times_to_retry = 6 + + def _on_error(error, bw) -> bool: + assert isinstance(error, BulkWriteFailure) + should_retry = error.attempts < times_to_retry + if should_retry: + bw._total_retries += 1 + return should_retry + + bw.on_write_error(_on_error) + + for ref, data in self._doc_iter(1): + bw.update(ref, data) + bw.flush() + + self.assertEqual(bw._total_retries, times_to_retry) + self.assertEqual(len(bw._operations), 0) + + def test_serial_calls_send_correctly(self): + bw = NoSendBulkWriter( + self.client, options=BulkWriterOptions(mode=SendMode.serial) + ) + for ref, data in self._doc_iter(101): + bw.create(ref, data) + bw.flush() + # Full batches with 20 items should have been sent 5 times, and a 1-item + # batch should have been sent once. + self._verify_bw_activity(bw, [(20, 5,), (1, 1,)]) + + def test_separates_same_document(self): + bw = NoSendBulkWriter(self.client) + for ref, data in self._doc_iter(2, ["same-id", "same-id"]): + bw.create(ref, data) + bw.flush() + # Seeing the same document twice should lead to separate batches + # Expect to have sent 1-item batches twice. 
+ self._verify_bw_activity(bw, [(1, 2,)]) + + def test_separates_same_document_different_operation(self): + bw = NoSendBulkWriter(self.client) + for ref, data in self._doc_iter(1, ["same-id"]): + bw.create(ref, data) + bw.set(ref, data) + bw.flush() + # Seeing the same document twice should lead to separate batches. + # Expect to have sent 1-item batches twice. + self._verify_bw_activity(bw, [(1, 2,)]) + + def test_ensure_sending_repeatedly_callable(self): + bw = NoSendBulkWriter(self.client) + bw._is_sending = True + bw._ensure_sending() + + def test_flush_close_repeatedly_callable(self): + bw = NoSendBulkWriter(self.client) + bw.flush() + bw.flush() + bw.close() + + def test_flush_sends_in_progress(self): + bw = NoSendBulkWriter(self.client) + bw.create(self._get_document_reference(), {"whatever": "you want"}) + bw.flush() + self._verify_bw_activity(bw, [(1, 1,)]) + + def test_flush_sends_all_queued_batches(self): + bw = NoSendBulkWriter(self.client) + for _ in range(2): + bw.create(self._get_document_reference(), {"whatever": "you want"}) + bw._queued_batches.append(bw._operations) + bw._reset_operations() + bw.flush() + self._verify_bw_activity(bw, [(1, 2,)]) + + def test_cannot_add_after_close(self): + bw = NoSendBulkWriter(self.client) + bw.close() + self.assertRaises(Exception, bw._verify_not_closed) + + def test_multiple_flushes(self): + bw = NoSendBulkWriter(self.client) + bw.flush() + bw.flush() + + def test_update_raises_with_bad_option(self): + bw = NoSendBulkWriter(self.client) + self.assertRaises( + ValueError, + bw.update, + self._get_document_reference("id"), + {}, + option=ExistsOption(exists=True), + ) + + +class TestSyncBulkWriter(_SyncClientMixin, _BaseBulkWriterTests, unittest.TestCase): + """All BulkWriters are opaquely async, but this one simulates a BulkWriter + dealing with synchronous DocumentReferences.""" + + +class TestAsyncBulkWriter( + _AsyncClientMixin, _BaseBulkWriterTests, aiounittest.AsyncTestCase +): + """All BulkWriters are 
opaquely async, but this one simulates a BulkWriter + dealing with AsyncDocumentReferences.""" + + +class TestScheduling(unittest.TestCase): + def test_max_in_flight_honored(self): + bw = NoSendBulkWriter(Client()) + # Calling this method sets up all the internal timekeeping machinery + bw._rate_limiter.take_tokens(20) + + # Now we pretend that all tokens have been consumed. This will force us + # to wait actual, real world milliseconds before being cleared to send more + bw._rate_limiter._available_tokens = 0 + + st = datetime.datetime.now() + + # Make a real request, subject to the actual real world clock. + # As this request is 1/10th the per second limit, we should wait ~100ms + bw._request_send(50) + + self.assertGreater( + datetime.datetime.now() - st, datetime.timedelta(milliseconds=90), + ) + + def test_operation_retry_scheduling(self): + now = datetime.datetime.now() + one_second_from_now = now + datetime.timedelta(seconds=1) + + db = Client() + operation = BulkWriterCreateOperation( + reference=db.collection("asdf").document("asdf"), + document_data={"does.not": "matter"}, + ) + operation2 = BulkWriterCreateOperation( + reference=db.collection("different").document("document"), + document_data={"different": "values"}, + ) + + op1 = OperationRetry(operation=operation, run_at=now) + op2 = OperationRetry(operation=operation2, run_at=now) + op3 = OperationRetry(operation=operation, run_at=one_second_from_now) + + self.assertLess(op1, op3) + self.assertLess(op1, op3.run_at) + self.assertLess(op2, op3) + self.assertLess(op2, op3.run_at) + + # Because these have the same values for `run_at`, neither should conclude + # they are less than the other. It is okay that if we checked them with + # greater-than evaluation, they would return True (because + # @functools.total_ordering flips the result from __lt__). In practice, + # this only arises for actual ties, and we don't care how actual ties are + # ordered as we maintain the sorted list of scheduled retries. 
+ self.assertFalse(op1 < op2) + self.assertFalse(op2 < op1) diff --git a/tests/unit/v1/test_client.py b/tests/unit/v1/test_client.py index 0055dab2c..a46839ac5 100644 --- a/tests/unit/v1/test_client.py +++ b/tests/unit/v1/test_client.py @@ -369,6 +369,14 @@ def test_batch(self): self.assertIs(batch._client, client) self.assertEqual(batch._write_pbs, []) + def test_bulk_writer(self): + from google.cloud.firestore_v1.bulk_writer import BulkWriter + + client = self._make_default_one() + bulk_writer = client.bulk_writer() + self.assertIsInstance(bulk_writer, BulkWriter) + self.assertIs(bulk_writer._client, client) + def test_transaction(self): from google.cloud.firestore_v1.transaction import Transaction diff --git a/tests/unit/v1/test_rate_limiter.py b/tests/unit/v1/test_rate_limiter.py new file mode 100644 index 000000000..ea41905e4 --- /dev/null +++ b/tests/unit/v1/test_rate_limiter.py @@ -0,0 +1,200 @@ +# Copyright 2021 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import unittest +from typing import Optional + +import mock +import google +from google.cloud.firestore_v1 import rate_limiter + + +# Pick a point in time as the center of our universe for this test run. +# It is okay for this to update every time the tests are run. 
fake_now = datetime.datetime.utcnow()


def now_plus_n(seconds: int = 0, microseconds: int = 0) -> datetime.datetime:
    """Return ``fake_now`` shifted forward by the given offset.

    Args:
        seconds: Whole seconds to add to ``fake_now``.
        microseconds: Microseconds to add to ``fake_now``.

    Returns:
        The ``datetime.datetime`` that lies ``seconds`` + ``microseconds``
        after ``fake_now``.
    """
    return fake_now + datetime.timedelta(seconds=seconds, microseconds=microseconds)


class TestRateLimiter(unittest.TestCase):
    """Tests for ``rate_limiter.RateLimiter`` token accounting.

    Every test except ``test_utcnow`` patches ``rate_limiter.utcnow`` so the
    clock seen by the limiter only moves when the test moves it.
    """

    @mock.patch.object(google.cloud.firestore_v1.rate_limiter, "utcnow")
    def test_rate_limiter_basic(self, mocked_now):
        """Verifies that if the clock does not advance, the RateLimiter grants
        ``default_initial_tokens`` single-token requests and then returns 0.
        """
        mocked_now.return_value = fake_now
        # This RateLimiter will never advance. Poor fella.
        ramp = rate_limiter.RateLimiter()
        for _ in range(rate_limiter.default_initial_tokens):
            self.assertEqual(ramp.take_tokens(), 1)
        self.assertEqual(ramp.take_tokens(), 0)

    @mock.patch.object(google.cloud.firestore_v1.rate_limiter, "utcnow")
    def test_rate_limiter_with_refill(self, mocked_now):
        """Verifies that if the clock advances, the RateLimiter allows appropriate
        additional writes.
        """
        mocked_now.return_value = fake_now
        ramp = rate_limiter.RateLimiter()
        ramp._available_tokens = 0
        self.assertEqual(ramp.take_tokens(), 0)
        # Advance the clock 0.1 seconds; a tenth of the initial tokens should
        # become available again.
        mocked_now.return_value = now_plus_n(microseconds=100000)
        for _ in range(round(rate_limiter.default_initial_tokens / 10)):
            self.assertEqual(ramp.take_tokens(), 1)
        self.assertEqual(ramp.take_tokens(), 0)

    @mock.patch.object(google.cloud.firestore_v1.rate_limiter, "utcnow")
    def test_rate_limiter_phase_length(self, mocked_now):
        """Verifies that once a full phase has elapsed after a token was taken,
        the RateLimiter grants 1.5x ``default_initial_tokens`` requests before
        returning 0.
        """
        mocked_now.return_value = fake_now
        ramp = rate_limiter.RateLimiter()
        self.assertEqual(ramp.take_tokens(), 1)
        ramp._available_tokens = 0
        self.assertEqual(ramp.take_tokens(), 0)
        # Advance the clock 1 phase
        mocked_now.return_value = now_plus_n(
            seconds=rate_limiter.default_phase_length, microseconds=1,
        )
        for i in range(round(rate_limiter.default_initial_tokens * 3 / 2)):
            # Pin the exact grant (1), matching the sibling tests, instead of
            # accepting any truthy value.
            self.assertEqual(
                ramp.take_tokens(), 1, msg=f"token {i} should have been allowed"
            )
        self.assertEqual(ramp.take_tokens(), 0)

    @mock.patch.object(google.cloud.firestore_v1.rate_limiter, "utcnow")
    def test_rate_limiter_idle_phase_length(self, mocked_now):
        """Verifies that if the clock advances but nothing happens, the RateLimiter
        doesn't ramp up.
        """
        mocked_now.return_value = fake_now
        ramp = rate_limiter.RateLimiter()
        ramp._available_tokens = 0
        self.assertEqual(ramp.take_tokens(), 0)
        # Advance the clock 1 phase
        mocked_now.return_value = now_plus_n(
            seconds=rate_limiter.default_phase_length, microseconds=1,
        )
        for i in range(round(rate_limiter.default_initial_tokens)):
            self.assertEqual(
                ramp.take_tokens(), 1, msg=f"token {i} should have been allowed"
            )
            self.assertEqual(ramp._maximum_tokens, 500)
        self.assertEqual(ramp.take_tokens(), 0)

    @mock.patch.object(google.cloud.firestore_v1.rate_limiter, "utcnow")
    def test_take_batch_size(self, mocked_now):
        """Verifies multi-token requests: whole pages are granted while enough
        tokens remain, and ``allow_less=True`` yields the partial remainder
        and then 0.
        """
        page_size: int = 20
        mocked_now.return_value = fake_now
        ramp = rate_limiter.RateLimiter()
        ramp._available_tokens = 15
        # Fewer tokens than a page are available; allow_less hands them over.
        self.assertEqual(ramp.take_tokens(page_size, allow_less=True), 15)
        # Advance the clock 1 phase
        mocked_now.return_value = now_plus_n(
            seconds=rate_limiter.default_phase_length, microseconds=1,
        )
        ramp._check_phase()
        self.assertEqual(ramp._maximum_tokens, 750)

        for page in range(740 // page_size):
            self.assertEqual(
                ramp.take_tokens(page_size),
                page_size,
                msg=f"page {page} should have been allowed",
            )
        # 740 of the 750 tokens are gone, so only a partial page of 10 is left.
        self.assertEqual(ramp.take_tokens(page_size, allow_less=True), 10)
        self.assertEqual(ramp.take_tokens(page_size, allow_less=True), 0)

    @mock.patch.object(google.cloud.firestore_v1.rate_limiter, "utcnow")
    def test_phase_progress(self, mocked_now):
        """Verifies that ``_phase`` and ``_maximum_tokens`` step up once per
        elapsed phase (500 -> 750 -> 1125) and stay put within a phase.
        """
        mocked_now.return_value = fake_now

        ramp = rate_limiter.RateLimiter()
        self.assertEqual(ramp._phase, 0)
        self.assertEqual(ramp._maximum_tokens, 500)
        ramp.take_tokens()

        # Advance the clock 1 phase
        mocked_now.return_value = now_plus_n(
            seconds=rate_limiter.default_phase_length, microseconds=1,
        )
        ramp.take_tokens()
        self.assertEqual(ramp._phase, 1)
        self.assertEqual(ramp._maximum_tokens, 750)

        # Advance the clock another phase
        mocked_now.return_value = now_plus_n(
            seconds=rate_limiter.default_phase_length * 2, microseconds=1,
        )
        ramp.take_tokens()
        self.assertEqual(ramp._phase, 2)
        self.assertEqual(ramp._maximum_tokens, 1125)

        # Advance the clock another microsecond and the phase should not advance
        mocked_now.return_value = now_plus_n(
            seconds=rate_limiter.default_phase_length * 2, microseconds=2,
        )
        ramp.take_tokens()
        self.assertEqual(ramp._phase, 2)
        self.assertEqual(ramp._maximum_tokens, 1125)

    @mock.patch.object(google.cloud.firestore_v1.rate_limiter, "utcnow")
    def test_global_max_tokens(self, mocked_now):
        """Verifies that ``global_max_tokens`` caps ``_maximum_tokens`` even as
        the phase counter keeps advancing.
        """
        mocked_now.return_value = fake_now

        ramp = rate_limiter.RateLimiter(global_max_tokens=499,)
        self.assertEqual(ramp._phase, 0)
        self.assertEqual(ramp._maximum_tokens, 499)
        ramp.take_tokens()

        # Advance the clock 1 phase
        mocked_now.return_value = now_plus_n(
            seconds=rate_limiter.default_phase_length, microseconds=1,
        )
        ramp.take_tokens()
        self.assertEqual(ramp._phase, 1)
        self.assertEqual(ramp._maximum_tokens, 499)

        # Advance the clock another phase
        mocked_now.return_value = now_plus_n(
            seconds=rate_limiter.default_phase_length * 2, microseconds=1,
        )
        ramp.take_tokens()
        self.assertEqual(ramp._phase, 2)
        self.assertEqual(ramp._maximum_tokens, 499)

        # Advance the clock another microsecond and the phase should not advance
        mocked_now.return_value = now_plus_n(
            seconds=rate_limiter.default_phase_length * 2, microseconds=2,
        )
        ramp.take_tokens()
        self.assertEqual(ramp._phase, 2)
        self.assertEqual(ramp._maximum_tokens, 499)

    def test_utcnow(self):
        """Verifies that the unpatched ``rate_limiter.utcnow`` returns a
        ``datetime.datetime``.
        """
        self.assertIsInstance(
            google.cloud.firestore_v1.rate_limiter.utcnow(), datetime.datetime,
        )