diff --git a/datacube/drivers/postgis/__init__.py b/datacube/drivers/postgis/__init__.py index 273ac73ba3..50348f828d 100644 --- a/datacube/drivers/postgis/__init__.py +++ b/datacube/drivers/postgis/__init__.py @@ -9,5 +9,6 @@ """ from ._connections import PostGisDb +from ._api import PostgisDbAPI -__all__ = ['PostGisDb'] +__all__ = ['PostGisDb', 'PostgisDbAPI'] diff --git a/datacube/drivers/postgis/_api.py b/datacube/drivers/postgis/_api.py index afc2b26edf..493231d271 100644 --- a/datacube/drivers/postgis/_api.py +++ b/datacube/drivers/postgis/_api.py @@ -182,6 +182,12 @@ def __init__(self, parentdb, connection): def in_transaction(self): return self._connection.in_transaction() + def begin(self): + self._connection.execute(text('BEGIN')) + + def commit(self): + self._connection.execute(text('COMMIT')) + def rollback(self): self._connection.execute(text('ROLLBACK')) diff --git a/datacube/drivers/postgis/_connections.py b/datacube/drivers/postgis/_connections.py index 6eb0fd8c90..7abbc604ca 100755 --- a/datacube/drivers/postgis/_connections.py +++ b/datacube/drivers/postgis/_connections.py @@ -19,7 +19,7 @@ from contextlib import contextmanager from typing import Any, Callable, Iterable, Mapping, Optional, Union, Type -from sqlalchemy import event, create_engine, text +from sqlalchemy import event, create_engine from sqlalchemy.engine import Engine from sqlalchemy.engine.url import URL as EngineUrl # noqa: N811 @@ -245,7 +245,7 @@ def spatial_indexes(self, refresh=False) -> Iterable[CRS]: return list(self.spindexes.keys()) @contextmanager - def connect(self): + def _connect(self): """ Borrow a connection from the pool. @@ -258,35 +258,12 @@ def connect(self): The connection can raise errors if not following this advice ("server closed the connection unexpectedly"), as some servers will aggressively close idle connections (eg. DEA's NCI servers). It also prevents the connection from being reused while borrowed. - """ - with self._engine.connect() as connection: - yield _api.PostgisDbAPI(self, connection) - connection.close() - - @contextmanager - def begin(self): - """ - Start a transaction. - - Returns an instance that will maintain a single connection in a transaction. - - Call commit() or rollback() to complete the transaction or use a context manager: - - with db.begin() as trans: - trans.insert_dataset(...) 
- - (Don't share an instance between threads) - :rtype: PostgresDBAPI + Low level context manager, use ._db_connection instead """ with self._engine.connect() as connection: - connection.execute(text('BEGIN')) try: yield _api.PostgisDbAPI(self, connection) - connection.execute(text('COMMIT')) - except Exception: # pylint: disable=broad-except - connection.execute(text('ROLLBACK')) - raise finally: connection.close() diff --git a/datacube/drivers/postgres/__init__.py b/datacube/drivers/postgres/__init__.py index a573d53814..6fe10be3c7 100644 --- a/datacube/drivers/postgres/__init__.py +++ b/datacube/drivers/postgres/__init__.py @@ -9,5 +9,6 @@ """ from ._connections import PostgresDb +from ._api import PostgresDbAPI -__all__ = ['PostgresDb'] +__all__ = ['PostgresDb', 'PostgresDbAPI'] diff --git a/datacube/drivers/postgres/_api.py b/datacube/drivers/postgres/_api.py index 178b39faa9..8d02f24bd9 100644 --- a/datacube/drivers/postgres/_api.py +++ b/datacube/drivers/postgres/_api.py @@ -182,9 +182,15 @@ def __init__(self, connection): def in_transaction(self): return self._connection.in_transaction() + def begin(self): + self._connection.execute(text('BEGIN')) + def rollback(self): self._connection.execute(text('ROLLBACK')) + def commit(self): + self._connection.execute(text('COMMIT')) + def execute(self, command): return self._connection.execute(command) diff --git a/datacube/drivers/postgres/_connections.py b/datacube/drivers/postgres/_connections.py index 56db60bfe7..3fc5ed7de0 100755 --- a/datacube/drivers/postgres/_connections.py +++ b/datacube/drivers/postgres/_connections.py @@ -19,7 +19,7 @@ from contextlib import contextmanager from typing import Callable, Optional, Union -from sqlalchemy import event, create_engine, text +from sqlalchemy import event, create_engine from sqlalchemy.engine import Engine from sqlalchemy.engine.url import URL as EngineUrl # noqa: N811 @@ -207,7 +207,7 @@ def init(self, with_permissions=True): return is_new @contextmanager - def connect(self): + def _connect(self): """ Borrow a connection from the pool. @@ -220,35 +220,12 @@ def connect(self): The connection can raise errors if not following this advice ("server closed the connection unexpectedly"), as some servers will aggressively close idle connections (eg. DEA's NCI servers). It also prevents the connection from being reused while borrowed. - """ - with self._engine.connect() as connection: - yield _api.PostgresDbAPI(connection) - connection.close() - - @contextmanager - def begin(self): - """ - Start a transaction. - - Returns an instance that will maintain a single connection in a transaction. - - Call commit() or rollback() to complete the transaction or use a context manager: - - with db.begin() as trans: - trans.insert_dataset(...) 
-
-        (Don't share an instance between threads)
-        :rtype: PostgresDBAPI
+        Low level context manager, use ._db_connection instead
         """
         with self._engine.connect() as connection:
-            connection.execute(text('BEGIN'))
             try:
                 yield _api.PostgresDbAPI(connection)
-                connection.execute(text('COMMIT'))
-            except Exception:  # pylint: disable=broad-except
-                connection.execute(text('ROLLBACK'))
-                raise
             finally:
                 connection.close()
diff --git a/datacube/index/abstract.py b/datacube/index/abstract.py
index 52636465ea..80bee46a1f 100644
--- a/datacube/index/abstract.py
+++ b/datacube/index/abstract.py
@@ -5,6 +5,7 @@
 import datetime
 import logging
 from pathlib import Path
+from threading import Lock
 from abc import ABC, abstractmethod
 from typing import (Any, Iterable, Iterator,
@@ -13,11 +14,13 @@
 from uuid import UUID

 from datacube.config import LocalConfig
+from datacube.index.exceptions import TransactionException
 from datacube.index.fields import Field
 from datacube.model import Dataset, MetadataType, Range
 from datacube.model import DatasetType as Product
 from datacube.utils import cached_property, read_documents, InvalidDocException
 from datacube.utils.changes import AllowPolicy, Change, Offset
+from datacube.utils.generic import thread_local_cache
 from datacube.utils.geometry import CRS, Geometry, box

 _LOG = logging.getLogger(__name__)
@@ -988,6 +991,163 @@ def _extract_geom_from_query(self, q: Mapping[str, QueryField]) -> Optional[Geom
         return geom


+class AbstractTransaction(ABC):
+    """
+    Abstract base class for a Transaction Manager. All index implementations should extend this base class.
+
+    Thread-local storage and locks ensure one active transaction per index per thread.
+    """
+
+    def __init__(self, index_id: str):
+        self._connection: Any = None
+        self._tls_id = f"txn-{index_id}"
+        self._obj_lock = Lock()
+
+    # Main Transaction API
+    def begin(self) -> None:
+        """
+        Start a new transaction.
+
+        Raises an error if a transaction is already active for this thread.
+
+        Calls the implementation-specific _new_connection() method and manages thread-local storage and locks.
+        """
+        with self._obj_lock:
+            if self._connection is not None:
+                raise ValueError("Cannot start a new transaction as one is already active")
+            self._tls_stash()
+
+    def commit(self) -> None:
+        """
+        Commit the transaction.
+
+        Raises an error if the transaction is not active.
+
+        Calls the implementation-specific _commit() method, and manages thread-local storage and locks.
+        """
+        with self._obj_lock:
+            if self._connection is None:
+                raise ValueError("Cannot commit inactive transaction")
+            self._commit()
+            self._release_connection()
+            self._connection = None
+            self._tls_purge()
+
+    def rollback(self) -> None:
+        """
+        Rollback the transaction.
+
+        Raises an error if the transaction is not active.
+
+        Calls the implementation-specific _rollback() method, and manages thread-local storage and locks.
+        """
+        with self._obj_lock:
+            if self._connection is None:
+                raise ValueError("Cannot rollback inactive transaction")
+            self._rollback()
+            self._release_connection()
+            self._connection = None
+            self._tls_purge()
+
+    @property
+    def active(self):
+        """
+        :return: True if the transaction is active.
+        """
+        return self._connection is not None
+
+    # Manage thread-local storage
+    def _tls_stash(self) -> None:
+        """
+        Check TLS is empty, create a new connection and stash it.
+ :return: + """ + stored_val = thread_local_cache(self._tls_id) + if stored_val is not None: + raise ValueError("Cannot start a new transaction as one is already active for this thread") + self._connection = self._new_connection() + thread_local_cache(self._tls_id, purge=True) + thread_local_cache(self._tls_id, self) + + def _tls_purge(self) -> None: + thread_local_cache(self._tls_id, purge=True) + + # Commit/Rollback exceptions for Context Manager usage patterns + def commit_exception(self, errmsg: str) -> TransactionException: + return TransactionException(errmsg, commit=True) + + def rollback_exception(self, errmsg: str) -> TransactionException: + return TransactionException(errmsg, commit=False) + + # Context Manager Interface + def __enter__(self): + self.begin() + return self + + def __exit__(self, exc_type, exc_value, traceback): + if not self.active: + # User has already manually committed or rolled back. + return True + if exc_type is not None and issubclass(exc_type, TransactionException): + # User raised a TransactionException, Commit or rollback as per exception + if exc_value.commit: + self.commit() + else: + self.rollback() + # Tell runtime exception is caught and handled. + return True + elif exc_value is not None: + # Any other exception - rollback + self.rollback() + # Instruct runtime to rethrow exception + return False + else: + # Exited without exception - commit and continue + self.commit() + return True + + # Internal abstract methods for implementation-specific functionality + @abstractmethod + def _new_connection(self) -> Any: + """ + :return: a new index driver object representing a database connection or equivalent against which transactions + will be executed. + """ + + @abstractmethod + def _commit(self) -> None: + """ + Commit the transaction. + """ + + @abstractmethod + def _rollback(self) -> None: + """ + Rollback the transaction. + """ + + @abstractmethod + def _release_connection(self) -> None: + """ + Release the connection object stored in self._connection + """ + + +class UnhandledTransaction(AbstractTransaction): + # Minimal implementation for index drivers with no transaction handling. + def _new_connection(self) -> Any: + return True + + def _commit(self) -> None: + pass + + def _rollback(self) -> None: + pass + + def _release_connection(self) -> None: + pass + + class AbstractIndex(ABC): """ Abstract base class for an Index. All Index implementations should @@ -1004,31 +1164,33 @@ class AbstractIndex(ABC): # supports lineage supports_lineage = True supports_source_filters = True + # Supports ACID transactions + supports_transactions = False @property @abstractmethod def url(self) -> str: - ... + """A string representing the index""" @property @abstractmethod def users(self) -> AbstractUserResource: - ... + """A User Resource instance for the index""" @property @abstractmethod def metadata_types(self) -> AbstractMetadataTypeResource: - ... + """A MetadataType Resource instance for the index""" @property @abstractmethod def products(self) -> AbstractProductResource: - ... + """A Product Resource instance for the index""" @property @abstractmethod def datasets(self) -> AbstractDatasetResource: - ... + """A Dataset Resource instance for the index""" @classmethod @abstractmethod @@ -1037,24 +1199,46 @@ def from_config(cls, application_name: Optional[str] = None, validate_connection: bool = True ) -> "AbstractIndex": - ... 
+ """Instantiate a new index from a LocalConfig object""" @classmethod @abstractmethod def get_dataset_fields(cls, doc: dict ) -> Mapping[str, Field]: - ... + """Return dataset search fields from a metadata type document""" @abstractmethod def init_db(self, with_default_types: bool = True, with_permissions: bool = True) -> bool: - ... + """ + Initialise an empty database. + + :param with_default_types: Whether to create default metadata types + :param with_permissions: Whether to create db permissions + :return: true if the database was created, false if already exists + """ @abstractmethod def close(self) -> None: - ... + """ + Close and cleanup the Index. + """ + + @property + @abstractmethod + def index_id(self) -> str: + """ + :return: Unique ID for this index + (e.g. same database/dataset storage + same index driver implementation = same id) + """ + + @abstractmethod + def transaction(self) -> AbstractTransaction: + """ + :return: a Transaction context manager for this index. + """ @abstractmethod def create_spatial_index(self, crs: CRS) -> bool: @@ -1066,6 +1250,12 @@ def create_spatial_index(self, crs: CRS) -> bool: None if spatial indexes are not supported. """ + def thread_transaction(self) -> Optional["AbstractTransaction"]: + """ + :return: The existing Transaction object cached in thread-local storage for this index, if there is one. + """ + return thread_local_cache(f"txn-{self.index_id}", None) + def spatial_indexes(self, refresh=False) -> Iterable[CRS]: """ Return a list of CRSs for which spatiotemporal indexes exist in the database. diff --git a/datacube/index/exceptions.py b/datacube/index/exceptions.py index 9e731c0b9f..e75a4135ae 100644 --- a/datacube/index/exceptions.py +++ b/datacube/index/exceptions.py @@ -14,3 +14,9 @@ class MissingRecordError(Exception): class IndexSetupError(Exception): pass + + +class TransactionException(Exception): # noqa: N818 + def __init__(self, *args, commit=False, **kwargs): + super().__init__(*args, **kwargs) + self.commit = commit diff --git a/datacube/index/memory/index.py b/datacube/index/memory/index.py index fd2045c04a..1a1385ffe3 100644 --- a/datacube/index/memory/index.py +++ b/datacube/index/memory/index.py @@ -3,19 +3,24 @@ # Copyright (c) 2015-2022 ODC Contributors # SPDX-License-Identifier: Apache-2.0 import logging +from threading import Lock from datacube.index.memory._datasets import DatasetResource # type: ignore from datacube.index.memory._fields import get_dataset_fields from datacube.index.memory._metadata_types import MetadataTypeResource from datacube.index.memory._products import ProductResource from datacube.index.memory._users import UserResource -from datacube.index.abstract import AbstractIndex, AbstractIndexDriver +from datacube.index.abstract import AbstractIndex, AbstractIndexDriver, UnhandledTransaction from datacube.model import MetadataType from datacube.utils.geometry import CRS _LOG = logging.getLogger(__name__) +counter = 0 +counter_lock = Lock() + + class Index(AbstractIndex): """ Lightweight in-memory index driver @@ -26,6 +31,10 @@ def __init__(self) -> None: self._metadata_types = MetadataTypeResource() self._products = ProductResource(self.metadata_types) self._datasets = DatasetResource(self.products) + global counter + with counter_lock: + counter = counter + 1 + self._index_id = f"memory={counter}" @property def users(self) -> UserResource: @@ -47,6 +56,13 @@ def datasets(self) -> DatasetResource: def url(self) -> str: return "memory" + @property + def index_id(self) -> str: + return self._index_id 
+ + def transaction(self) -> UnhandledTransaction: + return UnhandledTransaction(self.index_id) + @classmethod def from_config(cls, config, application_name=None, validate_connection=True): return cls() diff --git a/datacube/index/null/index.py b/datacube/index/null/index.py index 36a94b8783..a678f7998a 100644 --- a/datacube/index/null/index.py +++ b/datacube/index/null/index.py @@ -8,7 +8,7 @@ from datacube.index.null._metadata_types import MetadataTypeResource from datacube.index.null._products import ProductResource from datacube.index.null._users import UserResource -from datacube.index.abstract import AbstractIndex, AbstractIndexDriver +from datacube.index.abstract import AbstractIndex, AbstractIndexDriver, UnhandledTransaction from datacube.model import MetadataType from datacube.model.fields import get_dataset_fields from datacube.utils.geometry import CRS @@ -49,6 +49,13 @@ def datasets(self) -> DatasetResource: def url(self) -> str: return "null" + @property + def index_id(self) -> str: + return "null" + + def transaction(self) -> UnhandledTransaction: + return UnhandledTransaction(self.index_id) + @classmethod def from_config(cls, config, application_name=None, validate_connection=True): return cls() diff --git a/datacube/index/postgis/_datasets.py b/datacube/index/postgis/_datasets.py index 7e43e7f210..e11646ade0 100755 --- a/datacube/index/postgis/_datasets.py +++ b/datacube/index/postgis/_datasets.py @@ -17,6 +17,7 @@ from datacube.drivers.postgis._fields import SimpleDocField, DateDocField from datacube.drivers.postgis._schema import Dataset as SQLDataset from datacube.index.abstract import AbstractDatasetResource, DatasetSpatialMixin, DSID +from datacube.index.postgis._transaction import IndexResourceAddIn from datacube.model import Dataset, Product from datacube.model.fields import Field from datacube.model.utils import flatten_datasets @@ -32,20 +33,21 @@ # pylint: disable=too-many-public-methods, too-many-lines -class DatasetResource(AbstractDatasetResource): +class DatasetResource(AbstractDatasetResource, IndexResourceAddIn): """ :type _db: datacube.drivers.postgis._connections.PostgresDb :type types: datacube.index._products.ProductResource """ - def __init__(self, db, product_resource): + def __init__(self, db, index): """ :type db: datacube.drivers.postgis._connections.PostgresDb :type product_resource: datacube.index._products.ProductResource """ self._db = db - self.types = product_resource - self.products = product_resource + self._index = index + self.types = self._index.products # types is a compatibility alias for products. 
+ self.products = self._index.products def get(self, id_: Union[str, UUID], include_sources=False): """ @@ -58,7 +60,7 @@ def get(self, id_: Union[str, UUID], include_sources=False): if isinstance(id_, str): id_ = UUID(id_) - with self._db.connect() as connection: + with self._db_connection() as connection: if not include_sources: dataset = connection.get_dataset(id_) return self._make(dataset, full_info=True) if dataset else None @@ -87,7 +89,7 @@ def to_uuid(x): ids = [to_uuid(i) for i in ids] - with self._db.connect() as connection: + with self._db_connection() as connection: rows = connection.get_datasets(ids) return [self._make(r, full_info=True) for r in rows] @@ -100,7 +102,7 @@ def get_derived(self, id_): """ if not isinstance(id_, UUID): id_ = UUID(id_) - with self._db.connect() as connection: + with self._db_connection() as connection: return [ self._make(result, full_info=True) for result in connection.get_derived_datasets(id_) @@ -113,7 +115,7 @@ def has(self, id_): :param typing.Union[UUID, str] id_: dataset id :rtype: bool """ - with self._db.connect() as connection: + with self._db_connection() as connection: return connection.contains_dataset(id_) def bulk_has(self, ids_): @@ -126,7 +128,7 @@ def bulk_has(self, ids_): :rtype: [bool] """ - with self._db.connect() as connection: + with self._db_connection() as connection: existing = set(connection.datasets_intersection(ids_)) return [x in existing for x in @@ -193,7 +195,7 @@ def process_bunch(dss, main_ds, transaction): dss = [dataset] - with self._db.begin() as transaction: + with self._db_connection(transaction=True) as transaction: process_bunch(dss, dataset, transaction) return dataset @@ -218,7 +220,7 @@ def load_field(f: Union[str, fields.Field]) -> fields.Field: expressions = [product.metadata_type.dataset_fields.get('product') == product.name] - with self._db.connect() as connection: + with self._db_connection() as connection: for record in connection.get_duplicates(group_fields, expressions): dataset_ids = set(record[0]) grouped_fields = tuple(record[1:]) @@ -288,7 +290,7 @@ def update(self, dataset: Dataset, updates_allowed=None): _LOG.info("Updating dataset %s", dataset.id) product = self.types.get_by_name(dataset.type.name) - with self._db.begin() as transaction: + with self._db_connection(transaction=True) as transaction: if not transaction.update_dataset(dataset.metadata_doc_without_lineage(), dataset.id, product.id): raise ValueError("Failed to update dataset %s..." 
% dataset.id) @@ -307,7 +309,7 @@ def insert_one(uri, transaction): # front of a stack for uri in new_uris[::-1]: if transaction is None: - with self._db.begin() as tr: + with self._db_connection(transaction=True) as tr: insert_one(uri, tr) else: insert_one(uri, transaction) @@ -318,7 +320,7 @@ def archive(self, ids): :param Iterable[UUID] ids: list of dataset ids to archive """ - with self._db.begin() as transaction: + with self._db_connection(transaction=True) as transaction: for id_ in ids: transaction.archive_dataset(id_) @@ -328,7 +330,7 @@ def restore(self, ids): :param Iterable[UUID] ids: list of dataset ids to restore """ - with self._db.begin() as transaction: + with self._db_connection(transaction=True) as transaction: for id_ in ids: transaction.restore_dataset(id_) @@ -338,7 +340,7 @@ def purge(self, ids: Iterable[DSID]): :param ids: iterable of dataset ids to purge """ - with self._db.begin() as transaction: + with self._db_connection(transaction=True) as transaction: for id_ in ids: transaction.delete_dataset(id_) @@ -352,7 +354,7 @@ def get_all_dataset_ids(self, archived: bool): :param archived: :rtype: list[UUID] """ - with self._db.begin() as transaction: + with self._db_connection(transaction=True) as transaction: return [dsid[0] for dsid in transaction.all_dataset_ids(archived)] def get_field_names(self, product_name=None): @@ -379,7 +381,7 @@ def get_locations(self, id_): :param typing.Union[UUID, str] id_: dataset id :rtype: list[str] """ - with self._db.connect() as connection: + with self._db_connection() as connection: return connection.get_locations(id_) def get_archived_locations(self, id_): @@ -389,7 +391,7 @@ def get_archived_locations(self, id_): :param typing.Union[UUID, str] id_: dataset id :rtype: list[str] """ - with self._db.connect() as connection: + with self._db_connection() as connection: return [uri for uri, archived_dt in connection.get_archived_locations(id_)] def get_archived_location_times(self, id_): @@ -399,7 +401,7 @@ def get_archived_location_times(self, id_): :param typing.Union[UUID, str] id_: dataset id :rtype: List[Tuple[str, datetime.datetime]] """ - with self._db.connect() as connection: + with self._db_connection() as connection: return list(connection.get_archived_locations(id_)) def add_location(self, id_, uri): @@ -414,7 +416,7 @@ def add_location(self, id_, uri): warnings.warn("Cannot add empty uri. (dataset %s)" % id_) return False - with self._db.connect() as connection: + with self._db_connection() as connection: return connection.insert_dataset_location(id_, uri) def get_datasets_for_location(self, uri, mode=None): @@ -425,7 +427,7 @@ def get_datasets_for_location(self, uri, mode=None): :param str mode: 'exact', 'prefix' or None (to guess) :return: """ - with self._db.connect() as connection: + with self._db_connection() as connection: return (self._make(row) for row in connection.get_datasets_for_location(uri, mode=mode)) def remove_location(self, id_, uri): @@ -436,7 +438,7 @@ def remove_location(self, id_, uri): :param str uri: fully qualified uri :returns bool: Was one removed? 
""" - with self._db.connect() as connection: + with self._db_connection() as connection: was_removed = connection.remove_location(id_, uri) return was_removed @@ -448,7 +450,7 @@ def archive_location(self, id_, uri): :param str uri: fully qualified uri :return bool: location was able to be archived """ - with self._db.connect() as connection: + with self._db_connection() as connection: was_archived = connection.archive_location(id_, uri) return was_archived @@ -460,7 +462,7 @@ def restore_location(self, id_, uri): :param str uri: fully qualified uri :return bool: location was able to be restored """ - with self._db.connect() as connection: + with self._db_connection() as connection: was_restored = connection.restore_location(id_, uri) return was_restored @@ -501,7 +503,7 @@ def search_by_metadata(self, metadata): :param dict metadata: :rtype: list[Dataset] """ - with self._db.connect() as connection: + with self._db_connection() as connection: for dataset in self._make_many(connection.search_datasets_by_metadata(metadata)): yield dataset @@ -650,7 +652,7 @@ def _do_search_by_product(self, query, return_fields=False, select_field_names=N else: select_fields = tuple(dataset_fields[field_name] for field_name in select_field_names) - with self._db.connect() as connection: + with self._db_connection() as connection: yield (product, connection.search_datasets( query_exprs, @@ -666,7 +668,7 @@ def _do_count_by_product(self, query): for q, product in product_queries: dataset_fields = product.metadata_type.dataset_fields query_exprs = tuple(fields.to_expressions(dataset_fields.get, **q)) - with self._db.connect() as connection: + with self._db_connection() as connection: count = connection.count_datasets(query_exprs) if count > 0: yield product, count @@ -691,7 +693,7 @@ def _do_time_count(self, period, query, ensure_single=False): for q, product in product_queries: dataset_fields = product.metadata_type.dataset_fields query_exprs = tuple(fields.to_expressions(dataset_fields.get, **q)) - with self._db.connect() as connection: + with self._db_connection() as connection: yield product, list(connection.count_datasets_through_time( start, end, @@ -738,7 +740,7 @@ def get_product_time_bounds(self, product: str): offset=max_offset, selection='greatest') - with self._db.connect() as connection: + with self._db_connection() as connection: result = connection.execute( select( [func.min(time_min.alchemy_expression), func.max(time_max.alchemy_expression)] @@ -792,7 +794,7 @@ class DatasetLight(result_type, DatasetSpatialMixin): class DatasetLight(result_type): # type: ignore __slots__ = () - with self._db.connect() as connection: + with self._db_connection() as connection: results = connection.search_unique_datasets( query_exprs, select_fields=select_fields, @@ -891,5 +893,5 @@ def get_custom_query_expressions(self, custom_query, custom_offsets): return custom_exprs def spatial_extent(self, ids: Iterable[DSID], crs: CRS = CRS("EPSG:4326")) -> Optional[Geometry]: - with self._db.connect() as connection: + with self._db_connection() as connection: return connection.spatial_extent(ids, crs) diff --git a/datacube/index/postgis/_metadata_types.py b/datacube/index/postgis/_metadata_types.py index 94556b07d6..e01afdde87 100644 --- a/datacube/index/postgis/_metadata_types.py +++ b/datacube/index/postgis/_metadata_types.py @@ -7,6 +7,7 @@ from cachetools.func import lru_cache from datacube.index.abstract import AbstractMetadataTypeResource +from datacube.index.postgis._transaction import IndexResourceAddIn from 
datacube.model import MetadataType from datacube.utils import jsonify_document, changes, _readable_offset from datacube.utils.changes import check_doc_unchanged, get_doc_changes @@ -14,12 +15,13 @@ _LOG = logging.getLogger(__name__) -class MetadataTypeResource(AbstractMetadataTypeResource): - def __init__(self, db): +class MetadataTypeResource(AbstractMetadataTypeResource, IndexResourceAddIn): + def __init__(self, db, index): """ :type db: datacube.drivers.postgis._connections.PostgresDb """ self._db = db + self._index = index self.get_unsafe = lru_cache()(self.get_unsafe) self.get_by_name_unsafe = lru_cache()(self.get_by_name_unsafe) @@ -67,7 +69,7 @@ def add(self, metadata_type, allow_table_lock=False): 'Metadata Type {}'.format(metadata_type.name) ) else: - with self._db.connect() as connection: + with self._db_connection() as connection: connection.insert_metadata_type( name=metadata_type.name, definition=metadata_type.definition, @@ -141,7 +143,7 @@ def update(self, metadata_type: MetadataType, allow_unsafe_updates=False, allow_ _LOG.info("Updating metadata type %s", metadata_type.name) - with self._db.connect() as connection: + with self._db_connection() as connection: connection.update_metadata_type( name=metadata_type.name, definition=metadata_type.definition, @@ -167,7 +169,7 @@ def update_document(self, definition, allow_unsafe_updates=False): # This is memoized in the constructor # pylint: disable=method-hidden def get_unsafe(self, id_): # type: ignore - with self._db.connect() as connection: + with self._db_connection() as connection: record = connection.get_metadata_type(id_) if record is None: raise KeyError('%s is not a valid MetadataType id') @@ -176,7 +178,7 @@ def get_unsafe(self, id_): # type: ignore # This is memoized in the constructor # pylint: disable=method-hidden def get_by_name_unsafe(self, name): # type: ignore - with self._db.connect() as connection: + with self._db_connection() as connection: record = connection.get_metadata_type_by_name(name) if not record: raise KeyError('%s is not a valid MetadataType name' % name) @@ -192,7 +194,7 @@ def check_field_indexes(self, allow_table_lock=False, If false, creation will be slightly slower and cannot be done in a transaction. 
""" - with self._db.connect() as connection: + with self._db_connection() as connection: connection.check_dynamic_fields( concurrently=not allow_table_lock, rebuild_indexes=rebuild_indexes, @@ -205,7 +207,7 @@ def get_all(self): :rtype: iter[datacube.model.MetadataType] """ - with self._db.connect() as connection: + with self._db_connection() as connection: return self._make_many(connection.get_all_metadata_types()) def _make_many(self, query_rows): diff --git a/datacube/index/postgis/_products.py b/datacube/index/postgis/_products.py index 2f380258b7..e94d306e84 100644 --- a/datacube/index/postgis/_products.py +++ b/datacube/index/postgis/_products.py @@ -8,6 +8,7 @@ from datacube.index import fields from datacube.index.abstract import AbstractProductResource +from datacube.index.postgis._transaction import IndexResourceAddIn from datacube.model import Product, MetadataType from datacube.utils import jsonify_document, changes, _readable_offset from datacube.utils.changes import check_doc_unchanged, get_doc_changes @@ -17,19 +18,20 @@ _LOG = logging.getLogger(__name__) -class ProductResource(AbstractProductResource): +class ProductResource(AbstractProductResource, IndexResourceAddIn): """ :type _db: datacube.drivers.postgis._connections.PostgresDb :type metadata_type_resource: datacube.index._metadata_types.MetadataTypeResource """ - def __init__(self, db, metadata_type_resource): + def __init__(self, db, index): """ :type db: datacube.drivers.postgis._connections.PostgresDb :type metadata_type_resource: datacube.index._metadata_types.MetadataTypeResource """ self._db = db - self.metadata_type_resource = metadata_type_resource + self._index = index + self.metadata_type_resource = self._index.metadata_types self.get_unsafe = lru_cache()(self.get_unsafe) self.get_by_name_unsafe = lru_cache()(self.get_by_name_unsafe) @@ -74,7 +76,7 @@ def add(self, product, allow_table_lock=False): _LOG.warning('Adding metadata_type "%s" as it doesn\'t exist.', product.metadata_type.name) metadata_type = self.metadata_type_resource.add(product.metadata_type, allow_table_lock=allow_table_lock) - with self._db.connect() as connection: + with self._db_connection() as connection: connection.insert_product( name=product.name, metadata=product.metadata_doc, @@ -183,7 +185,7 @@ def update(self, product: Product, allow_unsafe_updates=False, allow_table_lock= metadata_type = self.metadata_type_resource.get_by_name(product.metadata_type.name) # TODO: should we add metadata type here? assert metadata_type, "TODO: should we add metadata type here?" 
-        with self._db.connect() as conn:
+        with self._db_connection() as conn:
             conn.update_product(
                 name=product.name,
                 metadata=product.metadata_doc,
@@ -221,7 +223,7 @@ def update_document(self, definition, allow_unsafe_updates=False, allow_table_lo
     # This is memoized in the constructor
     # pylint: disable=method-hidden
     def get_unsafe(self, id_):  # type: ignore
-        with self._db.connect() as connection:
+        with self._db_connection() as connection:
             result = connection.get_product(id_)
             if not result:
                 raise KeyError('"%s" is not a valid Product id' % id_)
@@ -230,7 +232,7 @@ def get_unsafe(self, id_):  # type: ignore
     # This is memoized in the constructor
     # pylint: disable=method-hidden
     def get_by_name_unsafe(self, name):  # type: ignore
-        with self._db.connect() as connection:
+        with self._db_connection() as connection:
             result = connection.get_product_by_name(name)
             if not result:
                 raise KeyError('"%s" is not a valid Product name' % name)
@@ -305,7 +307,7 @@ def get_all(self) -> Iterable[Product]:
         """
         Retrieve all Products
         """
-        with self._db.connect() as connection:
+        with self._db_connection() as connection:
             return (self._make(record) for record in connection.get_all_products())

     def _make_many(self, query_rows):
diff --git a/datacube/index/postgis/_transaction.py b/datacube/index/postgis/_transaction.py
new file mode 100644
index 0000000000..60994be7ec
--- /dev/null
+++ b/datacube/index/postgis/_transaction.py
@@ -0,0 +1,62 @@
+# This file is part of the Open Data Cube, see https://opendatacube.org for more information
+#
+# Copyright (c) 2015-2022 ODC Contributors
+# SPDX-License-Identifier: Apache-2.0
+
+from contextlib import contextmanager
+from sqlalchemy import text
+from typing import Any
+
+from datacube.drivers.postgis import PostGisDb
+from datacube.drivers.postgis._api import PostgisDbAPI
+from datacube.index.abstract import AbstractTransaction
+
+
+class PostgisTransaction(AbstractTransaction):
+    def __init__(self, db: PostGisDb, idx_id: str) -> None:
+        super().__init__(idx_id)
+        self._db = db
+
+    def _new_connection(self) -> Any:
+        dbconn = self._db.give_me_a_connection()
+        dbconn.execute(text('BEGIN'))
+        conn = PostgisDbAPI(self._db, dbconn)
+        return conn
+
+    def _commit(self) -> None:
+        self._connection.commit()
+
+    def _rollback(self) -> None:
+        self._connection.rollback()
+
+    def _release_connection(self) -> None:
+        self._connection._connection.close()
+        self._connection._connection = None
+
+
+class IndexResourceAddIn:
+    @contextmanager
+    def _db_connection(self, transaction: bool = False) -> PostgisDbAPI:
+        """
+        Context manager representing a database connection.
+
+        If there is an active transaction for this index in the current thread, the connection object from that
+        transaction is returned, with the active transaction remaining in control of commit and rollback.
+
+        If there is no active transaction and the transaction argument is True, a new transactionised connection
+        is returned, with this context manager handling commit and rollback.
+
+        If there is no active transaction and the transaction argument is False (the default), a new connection
+        is returned with autocommit semantics.
+
+        Note that autocommit behaviour is NOT available if there is an active transaction for the index
+        in the current thread.
+
+        In Resource Manager code, replace self._db.connect() with self._db_connection(), and replace
+        self._db.begin() with self._db_connection(transaction=True).
+
+        :param transaction: Use a transaction if one is not already active for the thread.
+ :return: A PostgresDbAPI object, with the specified transaction semantics. + """ + with self._index._active_connection(transaction=transaction) as conn: + yield conn diff --git a/datacube/index/postgis/_users.py b/datacube/index/postgis/_users.py index 85a2ac9e34..c604a676b4 100644 --- a/datacube/index/postgis/_users.py +++ b/datacube/index/postgis/_users.py @@ -4,21 +4,26 @@ # SPDX-License-Identifier: Apache-2.0 from typing import Iterable, Optional, Tuple from datacube.index.abstract import AbstractUserResource +from datacube.index.postgis._transaction import IndexResourceAddIn from datacube.drivers.postgis import PostGisDb -class UserResource(AbstractUserResource): - def __init__(self, db: PostGisDb) -> None: +class UserResource(AbstractUserResource, IndexResourceAddIn): + def __init__(self, + db: PostGisDb, + index: "datacube.index.postgis.index.Index" # noqa: F821 + ) -> None: """ :type db: datacube.drivers.postgis.PostGisDb """ self._db = db + self._index = index def grant_role(self, role: str, *usernames: str) -> None: """ Grant a role to users """ - with self._db.connect() as connection: + with self._db_connection() as connection: connection.grant_role(role, usernames) def create_user(self, username: str, password: str, @@ -26,14 +31,14 @@ def create_user(self, username: str, password: str, """ Create a new user. """ - with self._db.connect() as connection: + with self._db_connection() as connection: connection.create_user(username, password, role, description=description) def delete_user(self, *usernames: str) -> None: """ Delete a user """ - with self._db.connect() as connection: + with self._db_connection() as connection: connection.drop_users(usernames) def list_users(self) -> Iterable[Tuple[str, str, Optional[str]]]: @@ -41,6 +46,6 @@ def list_users(self) -> Iterable[Tuple[str, str, Optional[str]]]: :return: list of (role, user, description) :rtype: list[(str, str, str)] """ - with self._db.connect() as connection: + with self._db_connection() as connection: for role, user, description in connection.list_users(): yield role, user, description diff --git a/datacube/index/postgis/index.py b/datacube/index/postgis/index.py index 6b03f4bcf5..db8366160a 100644 --- a/datacube/index/postgis/index.py +++ b/datacube/index/postgis/index.py @@ -3,14 +3,16 @@ # Copyright (c) 2015-2020 ODC Contributors # SPDX-License-Identifier: Apache-2.0 import logging +from contextlib import contextmanager from typing import Iterable, Sequence -from datacube.drivers.postgis import PostGisDb +from datacube.drivers.postgis import PostGisDb, PostgisDbAPI +from datacube.index.postgis._transaction import PostgisTransaction from datacube.index.postgis._datasets import DatasetResource, DSID # type: ignore from datacube.index.postgis._metadata_types import MetadataTypeResource from datacube.index.postgis._products import ProductResource from datacube.index.postgis._users import UserResource -from datacube.index.abstract import AbstractIndex, AbstractIndexDriver, default_metadata_type_docs +from datacube.index.abstract import AbstractIndex, AbstractIndexDriver, default_metadata_type_docs, AbstractTransaction from datacube.model import MetadataType from datacube.utils.geometry import CRS @@ -48,6 +50,7 @@ class Index(AbstractIndex): # Hopefully can reinstate a simpler form of lineage support, but leave for now supports_lineage = False supports_source_filters = False + supports_transactions = True def __init__(self, db: PostGisDb) -> None: # POSTGIS driver is not stable with respect to database schema or 
internal APIs. @@ -56,10 +59,10 @@ def __init__(self, db: PostGisDb) -> None: WARNING: Database schema and internal APIs may change significantly between releases. Use at your own risk.""") self._db = db - self._users = UserResource(db) - self._metadata_types = MetadataTypeResource(db) - self._products = ProductResource(db, self.metadata_types) - self._datasets = DatasetResource(db, self.products) + self._users = UserResource(db, self) + self._metadata_types = MetadataTypeResource(db, self) + self._products = ProductResource(db, self) + self._datasets = DatasetResource(db, self) @property def users(self) -> UserResource: @@ -115,6 +118,13 @@ def close(self): """ self._db.close() + @property + def index_id(self) -> str: + return self.url + + def transaction(self) -> AbstractTransaction: + return PostgisTransaction(self._db, self.index_id) + def create_spatial_index(self, crs: CRS) -> bool: sp_idx = self._db.create_spatial_index(crs) return sp_idx is not None @@ -127,12 +137,53 @@ def update_spatial_index(self, product_names: Sequence[str] = [], dataset_ids: Sequence[DSID] = [] ) -> int: - with self._db.connect() as conn: + with self._active_connection(transaction=True) as conn: return conn.update_spindex(crses, product_names, dataset_ids) def __repr__(self): return "Index".format(self._db) + @contextmanager + def _active_connection(self, transaction: bool = False) -> PostgisDbAPI: + """ + Context manager representing a database connection. + + If there is an active transaction for this index in the current thread, the connection object from that + transaction is returned, with the active transaction remaining in control of commit and rollback. + + If there is no active transaction and the transaction argument is True, a new transactionised connection + is returned, with this context manager handling commit and rollback. + + If there is no active transaction and the transaction argument is False (the default), a new connection + is returned with autocommit semantics. + + Note that autocommit behaviour is NOT available if there is an active transaction for the index + and the active thread. + + :param transaction: Use a transaction if one is not already active for the thread. + :return: A PostgresDbAPI object, with the specified transaction semantics. + """ + trans = self.thread_transaction() + closing = False + if trans is not None: + # Use active transaction + yield trans._connection + elif transaction: + closing = True + with self._db._connect() as conn: + conn.begin() + try: + yield conn + conn.commit() + except Exception: # pylint: disable=broad-except + conn.rollback() + raise + else: + closing = True + # Autocommit behaviour: + with self._db._connect() as conn: + yield conn + class DefaultIndexDriver(AbstractIndexDriver): @staticmethod diff --git a/datacube/index/postgres/_datasets.py b/datacube/index/postgres/_datasets.py index 127a241dc3..dbb198dfa1 100755 --- a/datacube/index/postgres/_datasets.py +++ b/datacube/index/postgres/_datasets.py @@ -1,6 +1,6 @@ # This file is part of the Open Data Cube, see https://opendatacube.org for more information # -# Copyright (c) 2015-2020 ODC Contributors +# Copyright (c) 2015-2022 ODC Contributors # SPDX-License-Identifier: Apache-2.0 """ API for dataset indexing, access and search. 
@@ -17,6 +17,7 @@ from datacube.drivers.postgres._fields import SimpleDocField, DateDocField from datacube.drivers.postgres._schema import DATASET from datacube.index.abstract import AbstractDatasetResource, DatasetSpatialMixin, DSID +from datacube.index.postgres._transaction import IndexResourceAddIn from datacube.model import Dataset, DatasetType from datacube.model.fields import Field from datacube.model.utils import flatten_datasets @@ -29,19 +30,20 @@ # It's a public api, so we can't reorganise old methods. # pylint: disable=too-many-public-methods, too-many-lines -class DatasetResource(AbstractDatasetResource): +class DatasetResource(AbstractDatasetResource, IndexResourceAddIn): """ :type _db: datacube.drivers.postgres._connections.PostgresDb :type types: datacube.index._products.ProductResource """ - def __init__(self, db, dataset_type_resource): + def __init__(self, db, index): """ :type db: datacube.drivers.postgres._connections.PostgresDb :type dataset_type_resource: datacube.index._products.ProductResource """ self._db = db - self.types = dataset_type_resource + self._index = index + self.types = self._index.products def get(self, id_: Union[str, UUID], include_sources=False): """ @@ -54,7 +56,7 @@ def get(self, id_: Union[str, UUID], include_sources=False): if isinstance(id_, str): id_ = UUID(id_) - with self._db.connect() as connection: + with self._db_connection() as connection: if not include_sources: dataset = connection.get_dataset(id_) return self._make(dataset, full_info=True) if dataset else None @@ -83,7 +85,7 @@ def to_uuid(x): ids = [to_uuid(i) for i in ids] - with self._db.connect() as connection: + with self._db_connection() as connection: rows = connection.get_datasets(ids) return [self._make(r, full_info=True) for r in rows] @@ -96,7 +98,7 @@ def get_derived(self, id_): """ if not isinstance(id_, UUID): id_ = UUID(id_) - with self._db.connect() as connection: + with self._db_connection() as connection: return [ self._make(result, full_info=True) for result in connection.get_derived_datasets(id_) @@ -109,7 +111,7 @@ def has(self, id_): :param typing.Union[UUID, str] id_: dataset id :rtype: bool """ - with self._db.connect() as connection: + with self._db_connection() as connection: return connection.contains_dataset(id_) def bulk_has(self, ids_): @@ -122,7 +124,7 @@ def bulk_has(self, ids_): :rtype: [bool] """ - with self._db.connect() as connection: + with self._db_connection() as connection: existing = set(connection.datasets_intersection(ids_)) return [x in existing for x in @@ -182,7 +184,7 @@ def process_bunch(dss, main_ds, transaction): dss = [dataset] - with self._db.begin() as transaction: + with self._db_connection(transaction=True) as transaction: process_bunch(dss, dataset, transaction) return dataset @@ -207,7 +209,7 @@ def load_field(f: Union[str, fields.Field]) -> fields.Field: expressions = [product.metadata_type.dataset_fields.get('product') == product.name] - with self._db.connect() as connection: + with self._db_connection() as connection: for record in connection.get_duplicates(group_fields, expressions): dataset_ids = set(record[0]) grouped_fields = tuple(record[1:]) @@ -275,7 +277,7 @@ def update(self, dataset: Dataset, updates_allowed=None): _LOG.info("Updating dataset %s", dataset.id) product = self.types.get_by_name(dataset.type.name) - with self._db.begin() as transaction: + with self._db_connection(transaction=True) as transaction: if not transaction.update_dataset(dataset.metadata_doc_without_lineage(), dataset.id, product.id): raise 
ValueError("Failed to update dataset %s..." % dataset.id) @@ -294,7 +296,7 @@ def insert_one(uri, transaction): # front of a stack for uri in new_uris[::-1]: if transaction is None: - with self._db.begin() as tr: + with self._db_connection(transaction=True) as tr: insert_one(uri, tr) else: insert_one(uri, transaction) @@ -305,7 +307,7 @@ def archive(self, ids): :param Iterable[UUID] ids: list of dataset ids to archive """ - with self._db.begin() as transaction: + with self._db_connection(transaction=True) as transaction: for id_ in ids: transaction.archive_dataset(id_) @@ -315,7 +317,7 @@ def restore(self, ids): :param Iterable[UUID] ids: list of dataset ids to restore """ - with self._db.begin() as transaction: + with self._db_connection(transaction=True) as transaction: for id_ in ids: transaction.restore_dataset(id_) @@ -325,7 +327,7 @@ def purge(self, ids: Iterable[DSID]): :param ids: iterable of dataset ids to purge """ - with self._db.begin() as transaction: + with self._db_connection(transaction=True) as transaction: for id_ in ids: transaction.delete_dataset(id_) @@ -339,7 +341,7 @@ def get_all_dataset_ids(self, archived: bool): :param archived: :rtype: list[UUID] """ - with self._db.begin() as transaction: + with self._db_connection(transaction=True) as transaction: return [dsid[0] for dsid in transaction.all_dataset_ids(archived)] def get_field_names(self, product_name=None): @@ -366,7 +368,7 @@ def get_locations(self, id_): :param typing.Union[UUID, str] id_: dataset id :rtype: list[str] """ - with self._db.connect() as connection: + with self._db_connection() as connection: return connection.get_locations(id_) def get_archived_locations(self, id_): @@ -376,7 +378,7 @@ def get_archived_locations(self, id_): :param typing.Union[UUID, str] id_: dataset id :rtype: list[str] """ - with self._db.connect() as connection: + with self._db_connection() as connection: return [uri for uri, archived_dt in connection.get_archived_locations(id_)] def get_archived_location_times(self, id_): @@ -386,7 +388,7 @@ def get_archived_location_times(self, id_): :param typing.Union[UUID, str] id_: dataset id :rtype: List[Tuple[str, datetime.datetime]] """ - with self._db.connect() as connection: + with self._db_connection() as connection: return list(connection.get_archived_locations(id_)) def add_location(self, id_, uri): @@ -401,7 +403,7 @@ def add_location(self, id_, uri): warnings.warn("Cannot add empty uri. (dataset %s)" % id_) return False - with self._db.connect() as connection: + with self._db_connection() as connection: return connection.insert_dataset_location(id_, uri) def get_datasets_for_location(self, uri, mode=None): @@ -412,7 +414,7 @@ def get_datasets_for_location(self, uri, mode=None): :param str mode: 'exact', 'prefix' or None (to guess) :return: """ - with self._db.connect() as connection: + with self._db_connection() as connection: return (self._make(row) for row in connection.get_datasets_for_location(uri, mode=mode)) def remove_location(self, id_, uri): @@ -423,7 +425,7 @@ def remove_location(self, id_, uri): :param str uri: fully qualified uri :returns bool: Was one removed? 
""" - with self._db.connect() as connection: + with self._db_connection() as connection: was_removed = connection.remove_location(id_, uri) return was_removed @@ -435,7 +437,7 @@ def archive_location(self, id_, uri): :param str uri: fully qualified uri :return bool: location was able to be archived """ - with self._db.connect() as connection: + with self._db_connection() as connection: was_archived = connection.archive_location(id_, uri) return was_archived @@ -447,7 +449,7 @@ def restore_location(self, id_, uri): :param str uri: fully qualified uri :return bool: location was able to be restored """ - with self._db.connect() as connection: + with self._db_connection() as connection: was_restored = connection.restore_location(id_, uri) return was_restored @@ -488,7 +490,7 @@ def search_by_metadata(self, metadata): :param dict metadata: :rtype: list[Dataset] """ - with self._db.connect() as connection: + with self._db_connection() as connection: for dataset in self._make_many(connection.search_datasets_by_metadata(metadata)): yield dataset @@ -644,7 +646,7 @@ def _do_search_by_product(self, query, return_fields=False, select_field_names=N else: select_fields = tuple(dataset_fields[field_name] for field_name in select_field_names) - with self._db.connect() as connection: + with self._db_connection() as connection: yield (product, connection.search_datasets( query_exprs, @@ -660,7 +662,7 @@ def _do_count_by_product(self, query): for q, product in product_queries: dataset_fields = product.metadata_type.dataset_fields query_exprs = tuple(fields.to_expressions(dataset_fields.get, **q)) - with self._db.connect() as connection: + with self._db_connection() as connection: count = connection.count_datasets(query_exprs) if count > 0: yield product, count @@ -685,7 +687,7 @@ def _do_time_count(self, period, query, ensure_single=False): for q, product in product_queries: dataset_fields = product.metadata_type.dataset_fields query_exprs = tuple(fields.to_expressions(dataset_fields.get, **q)) - with self._db.connect() as connection: + with self._db_connection() as connection: yield product, list(connection.count_datasets_through_time( start, end, @@ -730,7 +732,7 @@ def get_product_time_bounds(self, product: str): offset=max_offset, selection='greatest') - with self._db.connect() as connection: + with self._db_connection() as connection: result = connection.execute( select( [func.min(time_min.alchemy_expression), func.max(time_max.alchemy_expression)] @@ -784,7 +786,7 @@ class DatasetLight(result_type, DatasetSpatialMixin): class DatasetLight(result_type): # type: ignore __slots__ = () - with self._db.connect() as connection: + with self._db_connection() as connection: results = connection.search_unique_datasets( query_exprs, select_fields=select_fields, diff --git a/datacube/index/postgres/_metadata_types.py b/datacube/index/postgres/_metadata_types.py index dbf4045b51..feef1935fe 100644 --- a/datacube/index/postgres/_metadata_types.py +++ b/datacube/index/postgres/_metadata_types.py @@ -7,6 +7,7 @@ from cachetools.func import lru_cache from datacube.index.abstract import AbstractMetadataTypeResource +from datacube.index.postgres._transaction import IndexResourceAddIn from datacube.model import MetadataType from datacube.utils import jsonify_document, changes, _readable_offset from datacube.utils.changes import check_doc_unchanged, get_doc_changes @@ -14,12 +15,13 @@ _LOG = logging.getLogger(__name__) -class MetadataTypeResource(AbstractMetadataTypeResource): - def __init__(self, db): +class 
MetadataTypeResource(AbstractMetadataTypeResource, IndexResourceAddIn): + def __init__(self, db, index): """ :type db: datacube.drivers.postgres._connections.PostgresDb """ self._db = db + self._index = index self.get_unsafe = lru_cache()(self.get_unsafe) self.get_by_name_unsafe = lru_cache()(self.get_by_name_unsafe) @@ -51,7 +53,8 @@ def add(self, metadata_type, allow_table_lock=False): Allow an exclusive lock to be taken on the table while creating the indexes. This will halt other user's requests until completed. - If false, creation will be slightly slower and cannot be done in a transaction. + If false (and a transaction is not already active), creation will be slightly slower + and cannot be done in a transaction. :rtype: datacube.model.MetadataType """ # This column duplication is getting out of hand: @@ -67,7 +70,7 @@ def add(self, metadata_type, allow_table_lock=False): 'Metadata Type {}'.format(metadata_type.name) ) else: - with self._db.connect() as connection: + with self._db_connection(transaction=allow_table_lock) as connection: connection.insert_metadata_type( name=metadata_type.name, definition=metadata_type.definition, @@ -141,7 +144,7 @@ def update(self, metadata_type: MetadataType, allow_unsafe_updates=False, allow_ _LOG.info("Updating metadata type %s", metadata_type.name) - with self._db.connect() as connection: + with self._db_connection(transaction=allow_table_lock) as connection: connection.update_metadata_type( name=metadata_type.name, definition=metadata_type.definition, @@ -167,7 +170,7 @@ def update_document(self, definition, allow_unsafe_updates=False): # This is memoized in the constructor # pylint: disable=method-hidden def get_unsafe(self, id_): # type: ignore - with self._db.connect() as connection: + with self._db_connection() as connection: record = connection.get_metadata_type(id_) if record is None: raise KeyError('%s is not a valid MetadataType id') @@ -176,7 +179,7 @@ def get_unsafe(self, id_): # type: ignore # This is memoized in the constructor # pylint: disable=method-hidden def get_by_name_unsafe(self, name): # type: ignore - with self._db.connect() as connection: + with self._db_connection() as connection: record = connection.get_metadata_type_by_name(name) if not record: raise KeyError('%s is not a valid MetadataType name' % name) @@ -192,7 +195,7 @@ def check_field_indexes(self, allow_table_lock=False, If false, creation will be slightly slower and cannot be done in a transaction. 
""" - with self._db.connect() as connection: + with self._db_connection(transaction=allow_table_lock) as connection: connection.check_dynamic_fields( concurrently=not allow_table_lock, rebuild_indexes=rebuild_indexes, @@ -205,7 +208,7 @@ def get_all(self): :rtype: iter[datacube.model.MetadataType] """ - with self._db.connect() as connection: + with self._db_connection() as connection: return self._make_many(connection.get_all_metadata_types()) def _make_many(self, query_rows): diff --git a/datacube/index/postgres/_products.py b/datacube/index/postgres/_products.py index 4bdd34635c..d0dff2be40 100644 --- a/datacube/index/postgres/_products.py +++ b/datacube/index/postgres/_products.py @@ -8,28 +8,31 @@ from datacube.index import fields from datacube.index.abstract import AbstractProductResource +from datacube.index.postgres._transaction import IndexResourceAddIn from datacube.model import DatasetType, MetadataType from datacube.utils import jsonify_document, changes, _readable_offset from datacube.utils.changes import check_doc_unchanged, get_doc_changes from typing import Iterable, cast + _LOG = logging.getLogger(__name__) -class ProductResource(AbstractProductResource): +class ProductResource(AbstractProductResource, IndexResourceAddIn): """ :type _db: datacube.drivers.postgres._connections.PostgresDb :type metadata_type_resource: datacube.index._metadata_types.MetadataTypeResource """ - def __init__(self, db, metadata_type_resource): + def __init__(self, db, index): """ :type db: datacube.drivers.postgres._connections.PostgresDb :type metadata_type_resource: datacube.index._metadata_types.MetadataTypeResource """ self._db = db - self.metadata_type_resource = metadata_type_resource + self._index = index + self.metadata_type_resource = self._index.metadata_types self.get_unsafe = lru_cache()(self.get_unsafe) self.get_by_name_unsafe = lru_cache()(self.get_by_name_unsafe) @@ -54,7 +57,8 @@ def add(self, product, allow_table_lock=False): Allow an exclusive lock to be taken on the table while creating the indexes. This will halt other user's requests until completed. - If false, creation will be slightly slower and cannot be done in a transaction. + If false (and there is no already active transaction), creation will be slightly slower + and cannot be done in a transaction. :param DatasetType product: Product to add :rtype: DatasetType """ @@ -74,7 +78,7 @@ def add(self, product, allow_table_lock=False): _LOG.warning('Adding metadata_type "%s" as it doesn\'t exist.', product.metadata_type.name) metadata_type = self.metadata_type_resource.add(product.metadata_type, allow_table_lock=allow_table_lock) - with self._db.connect() as connection: + with self._db_connection(transaction=allow_table_lock) as connection: connection.insert_product( name=product.name, metadata=product.metadata_doc, @@ -186,7 +190,7 @@ def update(self, product: DatasetType, allow_unsafe_updates=False, allow_table_l metadata_type = cast(MetadataType, self.metadata_type_resource.get_by_name(product.metadata_type.name)) # Given we cannot change metadata type because of the check above, and this is an # update method, the metadata type is guaranteed to already exist. 
- with self._db.connect() as conn: + with self._db_connection(transaction=allow_table_lock) as conn: conn.update_product( name=product.name, metadata=product.metadata_doc, @@ -224,7 +228,7 @@ def update_document(self, definition, allow_unsafe_updates=False, allow_table_lo # This is memoized in the constructor # pylint: disable=method-hidden def get_unsafe(self, id_): # type: ignore - with self._db.connect() as connection: + with self._db_connection() as connection: result = connection.get_product(id_) if not result: raise KeyError('"%s" is not a valid Product id' % id_) @@ -233,7 +237,7 @@ def get_unsafe(self, id_): # type: ignore # This is memoized in the constructor # pylint: disable=method-hidden def get_by_name_unsafe(self, name): # type: ignore - with self._db.connect() as connection: + with self._db_connection() as connection: result = connection.get_product_by_name(name) if not result: raise KeyError('"%s" is not a valid Product name' % name) @@ -305,7 +309,7 @@ def get_all(self) -> Iterable[DatasetType]: """ Retrieve all Products """ - with self._db.connect() as connection: + with self._db_connection() as connection: return (self._make(record) for record in connection.get_all_products()) def _make_many(self, query_rows): diff --git a/datacube/index/postgres/_transaction.py b/datacube/index/postgres/_transaction.py new file mode 100644 index 0000000000..2c9a7a17e7 --- /dev/null +++ b/datacube/index/postgres/_transaction.py @@ -0,0 +1,62 @@ +# This file is part of the Open Data Cube, see https://opendatacube.org for more information +# +# Copyright (c) 2015-2022 ODC Contributors +# SPDX-License-Identifier: Apache-2.0 + +from contextlib import contextmanager +from sqlalchemy import text +from typing import Any + +from datacube.drivers.postgres import PostgresDb +from datacube.drivers.postgres._api import PostgresDbAPI +from datacube.index.abstract import AbstractTransaction + + +class PostgresTransaction(AbstractTransaction): + def __init__(self, db: PostgresDb, idx_id: str) -> None: + super().__init__(idx_id) + self._db = db + + def _new_connection(self) -> Any: + dbconn = self._db.give_me_a_connection() + dbconn.execute(text('BEGIN')) + conn = PostgresDbAPI(dbconn) + return conn + + def _commit(self) -> None: + self._connection.commit() + + def _rollback(self) -> None: + self._connection.rollback() + + def _release_connection(self) -> None: + self._connection._connection.close() + self._connection._connection = None + + +class IndexResourceAddIn: + @contextmanager + def _db_connection(self, transaction: bool = False) -> PostgresDbAPI: + """ + Context manager representing a database connection. + + If there is an active transaction for this index in the current thread, the connection object from that + transaction is returned, with the active transaction remaining in control of commit and rollback. + + If there is no active transaction and the transaction argument is True, a new transactionised connection + is returned, with this context manager handling commit and rollback. + + If there is no active transaction and the transaction argument is False (the default), a new connection + is returned with autocommit semantics. + + Note that autocommit behaviour is NOT available if there is an active transaction for the index + and the active thread. + + In Resource Manager code replace self._db.connect() with self.db_connection(), and replace + self._db.begin() with self.db_connection(transaction=True). + + :param transaction: Use a transaction if one is not already active for the thread. 
+ :return: A PostgresDbAPI object, with the specified transaction semantics. + """ + with self._index._active_connection(transaction=transaction) as conn: + yield conn diff --git a/datacube/index/postgres/_users.py b/datacube/index/postgres/_users.py index f9cfd14152..2dd4c81877 100644 --- a/datacube/index/postgres/_users.py +++ b/datacube/index/postgres/_users.py @@ -1,24 +1,29 @@ # This file is part of the Open Data Cube, see https://opendatacube.org for more information # -# Copyright (c) 2015-2020 ODC Contributors +# Copyright (c) 2015-2022 ODC Contributors # SPDX-License-Identifier: Apache-2.0 from typing import Iterable, Optional, Tuple from datacube.index.abstract import AbstractUserResource +from datacube.index.postgres._transaction import IndexResourceAddIn from datacube.drivers.postgres import PostgresDb -class UserResource(AbstractUserResource): - def __init__(self, db: PostgresDb) -> None: +class UserResource(AbstractUserResource, IndexResourceAddIn): + def __init__(self, + db: PostgresDb, + index: "datacube.index.postgres.index.Index" # noqa: F821 + ) -> None: """ :type db: datacube.drivers.postgres._connections.PostgresDb """ self._db = db + self._index = index def grant_role(self, role: str, *usernames: str) -> None: """ Grant a role to users """ - with self._db.connect() as connection: + with self._db_connection() as connection: connection.grant_role(role, usernames) def create_user(self, username: str, password: str, @@ -26,14 +31,14 @@ def create_user(self, username: str, password: str, """ Create a new user. """ - with self._db.connect() as connection: + with self._db_connection() as connection: connection.create_user(username, password, role, description=description) def delete_user(self, *usernames: str) -> None: """ Delete a user """ - with self._db.connect() as connection: + with self._db_connection() as connection: connection.drop_users(usernames) def list_users(self) -> Iterable[Tuple[str, str, Optional[str]]]: @@ -41,6 +46,6 @@ def list_users(self) -> Iterable[Tuple[str, str, Optional[str]]]: :return: list of (role, user, description) :rtype: list[(str, str, str)] """ - with self._db.connect() as connection: + with self._db_connection() as connection: for role, user, description in connection.list_users(): yield role, user, description diff --git a/datacube/index/postgres/index.py b/datacube/index/postgres/index.py index 16e4daba10..d56016bf6b 100644 --- a/datacube/index/postgres/index.py +++ b/datacube/index/postgres/index.py @@ -3,13 +3,15 @@ # Copyright (c) 2015-2020 ODC Contributors # SPDX-License-Identifier: Apache-2.0 import logging +from contextlib import contextmanager -from datacube.drivers.postgres import PostgresDb +from datacube.drivers.postgres import PostgresDb, PostgresDbAPI +from datacube.index.postgres._transaction import PostgresTransaction from datacube.index.postgres._datasets import DatasetResource # type: ignore from datacube.index.postgres._metadata_types import MetadataTypeResource from datacube.index.postgres._products import ProductResource from datacube.index.postgres._users import UserResource -from datacube.index.abstract import AbstractIndex, AbstractIndexDriver, default_metadata_type_docs +from datacube.index.abstract import AbstractIndex, AbstractIndexDriver, default_metadata_type_docs, AbstractTransaction from datacube.model import MetadataType from datacube.utils.geometry import CRS @@ -40,13 +42,15 @@ class Index(AbstractIndex): :type metadata_types: datacube.index._metadata_types.MetadataTypeResource """ + supports_transactions 
= True + def __init__(self, db: PostgresDb) -> None: self._db = db - self._users = UserResource(db) - self._metadata_types = MetadataTypeResource(db) - self._products = ProductResource(db, self.metadata_types) - self._datasets = DatasetResource(db, self.products) + self._users = UserResource(db, self) + self._metadata_types = MetadataTypeResource(db, self) + self._products = ProductResource(db, self) + self._datasets = DatasetResource(db, self) @property def users(self) -> UserResource: @@ -99,12 +103,60 @@ def close(self): """ self._db.close() + @property + def index_id(self) -> str: + return f"legacy_{self.url}" + + def transaction(self) -> AbstractTransaction: + return PostgresTransaction(self._db, self.index_id) + def create_spatial_index(self, crs: CRS) -> None: _LOG.warning("postgres driver does not support spatio-temporal indexes") def __repr__(self): return "Index".format(self._db) + @contextmanager + def _active_connection(self, transaction: bool = False) -> PostgresDbAPI: + """ + Context manager representing a database connection. + + If there is an active transaction for this index in the current thread, the connection object from that + transaction is returned, with the active transaction remaining in control of commit and rollback. + + If there is no active transaction and the transaction argument is True, a new transactionised connection + is returned, with this context manager handling commit and rollback. + + If there is no active transaction and the transaction argument is False (the default), a new connection + is returned with autocommit semantics. + + Note that autocommit behaviour is NOT available if there is an active transaction for the index + and the active thread. + + :param transaction: Use a transaction if one is not already active for the thread. + :return: A PostgresDbAPI object, with the specified transaction semantics. + """ + trans = self.thread_transaction() + closing = False + if trans is not None: + # Use active transaction + yield trans._connection + elif transaction: + closing = True + with self._db._connect() as conn: + conn.begin() + try: + yield conn + conn.commit() + except Exception: # pylint: disable=broad-except + conn.rollback() + raise + else: + closing = True + # Autocommit behaviour: + with self._db._connect() as conn: + yield conn + class DefaultIndexDriver(AbstractIndexDriver): aliases = ['postgres'] diff --git a/docs/about/whats_new.rst b/docs/about/whats_new.rst index 74c720703b..f902a14c9a 100644 --- a/docs/about/whats_new.rst +++ b/docs/about/whats_new.rst @@ -37,6 +37,8 @@ v1.8.next - Implement `patch_url` argument to `dc.load()` and `dc.load_data()` to provide a way to sign dataset URIs, as is required to access some commercial archives (e.g. Microsoft Planetary Computer). API is based on the `odc-stac` implementation. Only works for direct loading. More work required for deferred (i.e. Dask) loading. 
(:pull: `1317`) +- Implement public-facing index-driver-independent API for managing database transactions, as per Enhancement Proposal + EP07 (:pull: `1318`) - Update Conda environment to match dependencies in setup.py (:pull: `1319`) diff --git a/integration_tests/conftest.py b/integration_tests/conftest.py index cbd45619b3..1a834a0984 100644 --- a/integration_tests/conftest.py +++ b/integration_tests/conftest.py @@ -376,15 +376,6 @@ def index_empty(local_config, uninitialised_postgres_db: PostgresDb): del index -@pytest.fixture -def initialised_postgres_db(index): - """ - Return a connection to an PostgreSQL database, initialised with the default schema - and tables. - """ - return index._db - - def remove_postgres_dynamic_indexes(): """ Clear any dynamically created postgresql indexes from the schema. diff --git a/integration_tests/index/test_config_docs.py b/integration_tests/index/test_config_docs.py index 1be00ee5bb..328aab35bf 100644 --- a/integration_tests/index/test_config_docs.py +++ b/integration_tests/index/test_config_docs.py @@ -54,20 +54,20 @@ @pytest.mark.parametrize('datacube_env_name', ('datacube', )) -def test_metadata_indexes_views_exist(initialised_postgres_db, default_metadata_type): +def test_metadata_indexes_views_exist(index, default_metadata_type): """ :type initialised_postgres_db: datacube.drivers.postgres._connections.PostgresDb :type default_metadata_type: datacube.model.MetadataType """ # Metadata indexes should no longer exist. - assert not _object_exists(initialised_postgres_db, 'dix_eo_platform') + assert not _object_exists(index, 'dix_eo_platform') # Ensure view was created (following naming conventions) - assert _object_exists(initialised_postgres_db, 'dv_eo_dataset') + assert _object_exists(index, 'dv_eo_dataset') @pytest.mark.parametrize('datacube_env_name', ('datacube', )) -def test_dataset_indexes_views_exist(initialised_postgres_db, ls5_telem_type): +def test_dataset_indexes_views_exist(index, ls5_telem_type): """ :type initialised_postgres_db: datacube.drivers.postgres._connections.PostgresDb :type ls5_telem_type: datacube.model.DatasetType @@ -75,30 +75,30 @@ def test_dataset_indexes_views_exist(initialised_postgres_db, ls5_telem_type): assert ls5_telem_type.name == 'ls5_telem_test' # Ensure field indexes were created for the dataset type (following the naming conventions): - assert _object_exists(initialised_postgres_db, "dix_ls5_telem_test_orbit") + assert _object_exists(index, "dix_ls5_telem_test_orbit") # Ensure it does not create a 'platform' index, because that's a fixed field # (ie. 
identical in every dataset of the type) - assert not _object_exists(initialised_postgres_db, "dix_ls5_telem_test_platform") + assert not _object_exists(index, "dix_ls5_telem_test_platform") # Ensure view was created (following naming conventions) - assert _object_exists(initialised_postgres_db, 'dv_ls5_telem_test_dataset') + assert _object_exists(index, 'dv_ls5_telem_test_dataset') # Ensure view was created (following naming conventions) - assert not _object_exists(initialised_postgres_db, + assert not _object_exists(index, 'dix_ls5_telem_test_gsi'), "indexed=false field gsi shouldn't have an index" @pytest.mark.parametrize('datacube_env_name', ('datacube', )) -def test_dataset_composite_indexes_exist(initialised_postgres_db, ls5_telem_type): +def test_dataset_composite_indexes_exist(index, ls5_telem_type): # This type has fields named lat/lon/time, so composite indexes should now exist for them: # (following the naming conventions) - assert _object_exists(initialised_postgres_db, "dix_ls5_telem_test_sat_path_sat_row_time") + assert _object_exists(index, "dix_ls5_telem_test_sat_path_sat_row_time") # But no individual field indexes for these - assert not _object_exists(initialised_postgres_db, "dix_ls5_telem_test_sat_path") - assert not _object_exists(initialised_postgres_db, "dix_ls5_telem_test_sat_row") - assert not _object_exists(initialised_postgres_db, "dix_ls5_telem_test_time") + assert not _object_exists(index, "dix_ls5_telem_test_sat_path") + assert not _object_exists(index, "dix_ls5_telem_test_sat_row") + assert not _object_exists(index, "dix_ls5_telem_test_time") @pytest.mark.parametrize('datacube_env_name', ('datacube', )) @@ -209,12 +209,12 @@ def test_field_expression_unchanged_postgis( ) -def _object_exists(db, index_name): - if db.driver_name == "postgis": +def _object_exists(index, index_name): + if index._db.driver_name == "postgis": schema_name = "odc" else: schema_name = "agdc" - with db.connect() as connection: + with index._active_connection() as connection: val = connection._connection.execute(f"SELECT to_regclass('{schema_name}.{index_name}')").scalar() return val in (index_name, f'{schema_name}.{index_name}') @@ -337,11 +337,11 @@ def test_update_dataset_type(index, ls5_telem_type, ls5_telem_doc, ga_metadata_t index.products.update_document(full_doc) # Remove fixed field, forcing a new index to be created (as datasets can now differ for the field). 
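Test helpers that previously took a raw `PostgresDb` (such as `_object_exists` above) now take the index itself and borrow a connection through `Index._active_connection()`, so they see uncommitted work whenever the calling test has a transaction open. A hypothetical helper in the same style, with the function and table names purely illustrative:

    def _count_datasets(index, table="agdc.dataset"):
        # Reuses the thread's active transaction connection if there is one,
        # otherwise borrows a short-lived autocommit connection from the pool.
        with index._active_connection() as connection:
            # connection is a PostgresDbAPI; ._connection is the underlying
            # SQLAlchemy connection, exactly as used by _object_exists() above.
            return connection._connection.execute(
                f"SELECT count(*) FROM {table}"
            ).scalar()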
- assert not _object_exists(index._db, 'dix_ls5_telem_test_product_type') + assert not _object_exists(index, 'dix_ls5_telem_test_product_type') del ls5_telem_doc['metadata']['product_type'] index.products.update_document(ls5_telem_doc) # Ensure was updated - assert _object_exists(index._db, 'dix_ls5_telem_test_product_type') + assert _object_exists(index, 'dix_ls5_telem_test_product_type') updated_type = index.products.get_by_name(ls5_telem_type.name) assert updated_type.definition['metadata'] == ls5_telem_doc['metadata'] @@ -548,7 +548,7 @@ def test_filter_types_by_search(index, ls5_telem_type): @pytest.mark.parametrize('datacube_env_name', ('datacube', )) -def test_update_metadata_type_doc(initialised_postgres_db, index, ls5_telem_type): +def test_update_metadata_type_doc(index, ls5_telem_type): type_doc = copy.deepcopy(ls5_telem_type.metadata_type.definition) type_doc['dataset']['search_fields']['test_indexed'] = { 'description': 'indexed test field', @@ -563,5 +563,5 @@ def test_update_metadata_type_doc(initialised_postgres_db, index, ls5_telem_type index.metadata_types.update_document(type_doc) assert ls5_telem_type.name == 'ls5_telem_test' - assert _object_exists(initialised_postgres_db, "dix_ls5_telem_test_test_indexed") - assert not _object_exists(initialised_postgres_db, "dix_ls5_telem_test_test_not_indexed") + assert _object_exists(index, "dix_ls5_telem_test_test_indexed") + assert not _object_exists(index, "dix_ls5_telem_test_test_not_indexed") diff --git a/integration_tests/index/test_index_data.py b/integration_tests/index/test_index_data.py index a41bc01fc0..a205a3836e 100755 --- a/integration_tests/index/test_index_data.py +++ b/integration_tests/index/test_index_data.py @@ -16,7 +16,6 @@ import pytest from dateutil import tz -from datacube.drivers.postgres import PostgresDb from datacube.index.exceptions import MissingRecordError from datacube.index import Index from datacube.model import Dataset, MetadataType @@ -71,10 +70,10 @@ } -def test_archive_datasets(index, initialised_postgres_db, local_config, default_metadata_type): +def test_archive_datasets(index, local_config, default_metadata_type): dataset_type = index.products.add_document(_pseudo_telemetry_dataset_type) - with initialised_postgres_db.begin() as transaction: - was_inserted = transaction.insert_dataset( + with index.transaction() as transaction: + was_inserted = transaction._connection.insert_dataset( _telemetry_dataset, _telemetry_uuid, dataset_type.id @@ -106,11 +105,11 @@ def test_archive_datasets(index, initialised_postgres_db, local_config, default_ assert not indexed_dataset.is_archived -def test_purge_datasets(index, initialised_postgres_db, local_config, default_metadata_type, clirunner): +def test_purge_datasets(index, local_config, default_metadata_type, clirunner): # Create dataset dataset_type = index.products.add_document(_pseudo_telemetry_dataset_type) - with initialised_postgres_db.begin() as transaction: - was_inserted = transaction.insert_dataset( + with index.transaction() as transaction: + was_inserted = transaction._connection.insert_dataset( _telemetry_dataset, _telemetry_uuid, dataset_type.id @@ -137,15 +136,15 @@ def test_purge_datasets(index, initialised_postgres_db, local_config, default_me assert index.datasets.get(_telemetry_uuid) is None -def test_purge_datasets_cli(index, initialised_postgres_db, local_config, default_metadata_type, clirunner): +def test_purge_datasets_cli(index, local_config, default_metadata_type, clirunner): dataset_type = 
index.products.add_document(_pseudo_telemetry_dataset_type) # Attempt to purge non-existent dataset should fail clirunner(['dataset', 'purge', str(_telemetry_uuid)], expect_success=False) # Create dataset - with initialised_postgres_db.begin() as transaction: - was_inserted = transaction.insert_dataset( + with index.transaction() as transaction: + was_inserted = transaction._connection.insert_dataset( _telemetry_dataset, _telemetry_uuid, dataset_type.id @@ -170,12 +169,12 @@ def test_purge_datasets_cli(index, initialised_postgres_db, local_config, defaul assert index.datasets.get(_telemetry_uuid) is None -def test_purge_all_datasets_cli(index, initialised_postgres_db, local_config, default_metadata_type, clirunner): +def test_purge_all_datasets_cli(index, local_config, default_metadata_type, clirunner): dataset_type = index.products.add_document(_pseudo_telemetry_dataset_type) # Create dataset - with initialised_postgres_db.begin() as transaction: - was_inserted = transaction.insert_dataset( + with index.transaction() as transaction: + was_inserted = transaction._connection.insert_dataset( _telemetry_dataset, _telemetry_uuid, dataset_type.id @@ -202,12 +201,12 @@ def test_purge_all_datasets_cli(index, initialised_postgres_db, local_config, de @pytest.fixture -def telemetry_dataset(index: Index, initialised_postgres_db: PostgresDb, default_metadata_type) -> Dataset: +def telemetry_dataset(index: Index, default_metadata_type) -> Dataset: dataset_type = index.products.add_document(_pseudo_telemetry_dataset_type) assert not index.datasets.has(_telemetry_uuid) - with initialised_postgres_db.begin() as transaction: - was_inserted = transaction.insert_dataset( + with index.transaction() as transaction: + was_inserted = transaction._connection.insert_dataset( _telemetry_dataset, _telemetry_uuid, dataset_type.id @@ -217,14 +216,14 @@ def telemetry_dataset(index: Index, initialised_postgres_db: PostgresDb, default return index.datasets.get(_telemetry_uuid) -def test_index_duplicate_dataset(index: Index, initialised_postgres_db: PostgresDb, +def test_index_duplicate_dataset(index: Index, local_config, default_metadata_type) -> None: dataset_type = index.products.add_document(_pseudo_telemetry_dataset_type) assert not index.datasets.has(_telemetry_uuid) - with initialised_postgres_db.begin() as transaction: - was_inserted = transaction.insert_dataset( + with index.transaction() as transaction: + was_inserted = transaction._connection.insert_dataset( _telemetry_dataset, _telemetry_uuid, dataset_type.id @@ -234,7 +233,7 @@ def test_index_duplicate_dataset(index: Index, initialised_postgres_db: Postgres assert index.datasets.has(_telemetry_uuid) # Insert again. 
- with initialised_postgres_db.connect() as connection: + with index._db._connect() as connection: was_inserted = connection.insert_dataset( _telemetry_dataset, _telemetry_uuid, @@ -273,28 +272,89 @@ def test_get_dataset(index: Index, telemetry_dataset: Dataset) -> None: 'f226a278-e422-11e6-b501-185e0f80a5c1']) == [] -def test_transactions(index: Index, - initialised_postgres_db: PostgresDb, - local_config, - default_metadata_type) -> None: - assert not index.datasets.has(_telemetry_uuid) - - dataset_type = index.products.add_document(_pseudo_telemetry_dataset_type) - with initialised_postgres_db.begin() as transaction: - was_inserted = transaction.insert_dataset( - _telemetry_dataset, - _telemetry_uuid, - dataset_type.id - ) - assert was_inserted - assert transaction.contains_dataset(_telemetry_uuid) - # Normal DB uses a separate connection: No dataset visible yet. - assert not index.datasets.has(_telemetry_uuid) - - transaction.rollback() - - # Should have been rolled back. - assert not index.datasets.has(_telemetry_uuid) +def test_transactions_api_ctx_mgr(index, + extended_eo3_metadata_type_doc, + ls8_eo3_product, + eo3_ls8_dataset_doc, + eo3_ls8_dataset2_doc): + from datacube.index.hl import Doc2Dataset + resolver = Doc2Dataset(index, products=[ls8_eo3_product.name], verify_lineage=False) + ds1, err = resolver(*eo3_ls8_dataset_doc) + ds2, err = resolver(*eo3_ls8_dataset2_doc) + with pytest.raises(Exception) as e: + with index.transaction() as trans: + assert index.datasets.get(ds1.id) is None + index.datasets.add(ds1) + assert index.datasets.get(ds1.id) is not None + raise Exception("Rollback!") + assert "Rollback!" in str(e.value) + assert index.datasets.get(ds1.id) is None + with index.transaction() as trans: + assert index.datasets.get(ds1.id) is None + index.datasets.add(ds1) + assert index.datasets.get(ds1.id) is not None + assert index.datasets.get(ds1.id) is not None + with index.transaction() as trans: + index.datasets.add(ds2) + assert index.datasets.get(ds2.id) is not None + raise trans.rollback_exception("Rollback") + assert index.datasets.get(ds1.id) is not None + assert index.datasets.get(ds2.id) is None + + +def test_transactions_api_manual(index, + extended_eo3_metadata_type_doc, + ls8_eo3_product, + eo3_ls8_dataset_doc, + eo3_ls8_dataset2_doc): + from datacube.index.hl import Doc2Dataset + resolver = Doc2Dataset(index, products=[ls8_eo3_product.name], verify_lineage=False) + ds1, err = resolver(*eo3_ls8_dataset_doc) + ds2, err = resolver(*eo3_ls8_dataset2_doc) + trans = index.transaction() + index.datasets.add(ds1) + assert index.datasets.get(ds1.id) is not None + trans.begin() + index.datasets.add(ds2) + assert index.datasets.get(ds1.id) is not None + assert index.datasets.get(ds2.id) is not None + trans.rollback() + assert index.datasets.get(ds1.id) is not None + assert index.datasets.get(ds2.id) is None + trans.begin() + index.datasets.add(ds2) + trans.commit() + assert index.datasets.get(ds1.id) is not None + assert index.datasets.get(ds2.id) is not None + + +def test_transactions_api_hybrid(index, + extended_eo3_metadata_type_doc, + ls8_eo3_product, + eo3_ls8_dataset_doc, + eo3_ls8_dataset2_doc): + from datacube.index.hl import Doc2Dataset + resolver = Doc2Dataset(index, products=[ls8_eo3_product.name], verify_lineage=False) + ds1, err = resolver(*eo3_ls8_dataset_doc) + ds2, err = resolver(*eo3_ls8_dataset2_doc) + with index.transaction() as trans: + assert index.datasets.get(ds1.id) is None + index.datasets.add(ds1) + assert index.datasets.get(ds1.id) is not None 
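The `test_transactions_api_manual` case above drives the same API without a context manager. Outside a test, that pattern looks roughly like this (a sketch only; `dc` is assumed to be an open `Datacube` and `ds` a resolved `Dataset`):

    trans = dc.index.transaction()   # driver-independent AbstractTransaction
    trans.begin()                    # take a connection and start a DB transaction
    try:
        dc.index.datasets.add(ds)    # index calls on this thread now join the transaction
        trans.commit()
    except Exception:
        trans.rollback()             # nothing is persisted if anything above failed
        raise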
+ trans.rollback() + assert index.datasets.get(ds1.id) is None + trans.begin() + assert index.datasets.get(ds1.id) is None + index.datasets.add(ds1) + assert index.datasets.get(ds1.id) is not None + trans.commit() + assert index.datasets.get(ds1.id) is not None + trans.begin() + index.datasets.add(ds2) + assert index.datasets.get(ds2.id) is not None + trans.rollback() + assert index.datasets.get(ds1.id) is not None + assert index.datasets.get(ds2.id) is None def test_get_missing_things(index: Index) -> None: diff --git a/integration_tests/index/test_memory_index.py b/integration_tests/index/test_memory_index.py index 97b276c140..a25a421e4a 100644 --- a/integration_tests/index/test_memory_index.py +++ b/integration_tests/index/test_memory_index.py @@ -552,3 +552,19 @@ def test_memory_dataset_add(dataset_add_configs, mem_index_fresh): ds_from_idx = idx.datasets.get(ds_.id, include_sources=True) assert ds_from_idx.sources['ab'].id == ds_.sources['ab'].id assert ds_from_idx.sources['ac'].sources["cd"].id == ds_.sources['ac'].sources['cd'].id + + +def test_mem_transactions(mem_index_fresh): + trans = mem_index_fresh.index.transaction() + assert not trans.active + trans.begin() + assert trans.active + trans.commit() + assert not trans.active + trans.begin() + assert mem_index_fresh.index.thread_transaction() == trans + with pytest.raises(ValueError): + trans.begin() + trans.rollback() + assert not trans.active + assert mem_index_fresh.index.thread_transaction() is None diff --git a/integration_tests/index/test_null_index.py b/integration_tests/index/test_null_index.py index 269018cbdb..8c235f0058 100644 --- a/integration_tests/index/test_null_index.py +++ b/integration_tests/index/test_null_index.py @@ -120,3 +120,20 @@ def test_null_dataset_resource(null_config): assert dc.index.datasets.search_summaries(foo="bar", baz=12) == [] assert dc.index.datasets.search_eager(foo="bar", baz=12) == [] assert dc.index.datasets.search_returning_datasets_light(("foo", "baz"), foo="bar", baz=12) == [] + + +def test_null_transactions(null_config): + with Datacube(config=null_config, validate_connection=True) as dc: + trans = dc.index.transaction() + assert not trans.active + trans.begin() + assert trans.active + trans.commit() + assert not trans.active + trans.begin() + assert dc.index.thread_transaction() == trans + with pytest.raises(ValueError): + trans.begin() + trans.rollback() + assert not trans.active + assert dc.index.thread_transaction() is None diff --git a/integration_tests/index/test_search_legacy.py b/integration_tests/index/test_search_legacy.py index 92e9b82e1a..d04ee2833d 100644 --- a/integration_tests/index/test_search_legacy.py +++ b/integration_tests/index/test_search_legacy.py @@ -18,7 +18,6 @@ from psycopg2._range import NumericRange from datacube.config import LocalConfig -from datacube.drivers.postgres import PostgresDb from datacube.drivers.postgres._connections import DEFAULT_DB_USER from datacube.index import Index from datacube.model import Dataset @@ -55,9 +54,9 @@ def pseudo_ls8_type(index, ga_metadata_type): @pytest.fixture -def pseudo_ls8_dataset(index, initialised_postgres_db, pseudo_ls8_type): +def pseudo_ls8_dataset(index, pseudo_ls8_type): id_ = str(uuid.uuid4()) - with initialised_postgres_db.connect() as connection: + with index._active_connection() as connection: was_inserted = connection.insert_dataset( { 'id': id_, @@ -107,10 +106,10 @@ def pseudo_ls8_dataset(index, initialised_postgres_db, pseudo_ls8_type): @pytest.fixture -def pseudo_ls8_dataset2(index, 
initialised_postgres_db, pseudo_ls8_type): +def pseudo_ls8_dataset2(index, pseudo_ls8_type): # Like the previous dataset, but a day later in time. id_ = str(uuid.uuid4()) - with initialised_postgres_db.connect() as connection: + with index._active_connection() as connection: was_inserted = connection.insert_dataset( { 'id': id_, @@ -162,7 +161,6 @@ def pseudo_ls8_dataset2(index, initialised_postgres_db, pseudo_ls8_type): # Datasets 3 and 4 mirror 1 and 2 but have a different path/row. @pytest.fixture def pseudo_ls8_dataset3(index: Index, - initialised_postgres_db: PostgresDb, pseudo_ls8_type: Product, pseudo_ls8_dataset: Dataset) -> Dataset: # Same as 1, but a different path/row @@ -174,7 +172,7 @@ def pseudo_ls8_dataset3(index: Index, 'satellite_ref_point_end': {'x': 116, 'y': 87}, } - with initialised_postgres_db.connect() as connection: + with index._active_connection() as connection: was_inserted = connection.insert_dataset( dataset_doc, id_, @@ -189,7 +187,6 @@ def pseudo_ls8_dataset3(index: Index, @pytest.fixture def pseudo_ls8_dataset4(index: Index, - initialised_postgres_db: PostgresDb, pseudo_ls8_type: Product, pseudo_ls8_dataset2: Dataset) -> Dataset: # Same as 2, but a different path/row @@ -201,7 +198,7 @@ def pseudo_ls8_dataset4(index: Index, 'satellite_ref_point_end': {'x': 116, 'y': 87}, } - with initialised_postgres_db.connect() as connection: + with index._active_connection() as connection: was_inserted = connection.insert_dataset( dataset_doc, id_, @@ -855,7 +852,7 @@ def test_cli_info(index: Index, assert yaml_docs[1]['id'] == str(pseudo_ls8_dataset2.id) -def test_cli_missing_info(clirunner, initialised_postgres_db): +def test_cli_missing_info(clirunner, index): id_ = str(uuid.uuid4()) result = clirunner( [ diff --git a/integration_tests/index/test_update_columns.py b/integration_tests/index/test_update_columns.py index 9e34575ae8..e64b518246 100644 --- a/integration_tests/index/test_update_columns.py +++ b/integration_tests/index/test_update_columns.py @@ -58,7 +58,7 @@ def test_added_column(clirunner, uninitialised_postgres_db): result = clirunner(["system", "init"]) assert "Created." in result.output - with uninitialised_postgres_db.connect() as connection: + with uninitialised_postgres_db._connect() as connection: assert check_column(connection, _schema.METADATA_TYPE.name, "updated") assert not check_column(connection, _schema.METADATA_TYPE.name, "fake_column") assert check_column(connection, _schema.PRODUCT.name, "updated") @@ -81,7 +81,7 @@ def test_readd_column(clirunner, uninitialised_postgres_db): result = clirunner(["system", "init"]) assert "Created." 
in result.output - with uninitialised_postgres_db.connect() as connection: + with uninitialised_postgres_db._connect() as connection: # Drop all the columns for an init rerun drop_column(connection, _schema.METADATA_TYPE.name, "updated") drop_column(connection, _schema.PRODUCT.name, "updated") @@ -95,7 +95,7 @@ def test_readd_column(clirunner, uninitialised_postgres_db): result = clirunner(["system", "init"]) - with uninitialised_postgres_db.connect() as connection: + with uninitialised_postgres_db._connect() as connection: assert check_column(connection, _schema.METADATA_TYPE.name, "updated") assert check_column(connection, _schema.PRODUCT.name, "updated") assert check_column(connection, _schema.DATASET.name, "updated") diff --git a/integration_tests/test_config_tool.py b/integration_tests/test_config_tool.py index d333058d3f..ef6cc5c53c 100644 --- a/integration_tests/test_config_tool.py +++ b/integration_tests/test_config_tool.py @@ -20,20 +20,18 @@ INVALID_MAPPING_DOCS = map(str, Path(__file__).parent.parent.joinpath('docs').glob('*')) -def _dataset_type_count(db): - with db.connect() as connection: +def _dataset_type_count(index): + with index._active_connection() as connection: return len(list(connection.get_all_products())) -def test_add_example_dataset_types(clirunner, initialised_postgres_db, default_metadata_type): +def test_add_example_dataset_types(clirunner, index, default_metadata_type): """ Add example mapping docs, to ensure they're valid and up-to-date. We add them all to a single database to check for things like duplicate ids. - - :type initialised_postgres_db: datacube.drivers.postgres._connections.PostgresDb """ - existing_mappings = _dataset_type_count(initialised_postgres_db) + existing_mappings = _dataset_type_count(index) print('{} mappings'.format(existing_mappings)) for mapping_path in EXAMPLE_DATASET_TYPE_DOCS: @@ -42,7 +40,7 @@ def test_add_example_dataset_types(clirunner, initialised_postgres_db, default_m result = clirunner(['-v', 'product', 'add', mapping_path]) assert result.exit_code == 0 - mappings_count = _dataset_type_count(initialised_postgres_db) + mappings_count = _dataset_type_count(index) assert mappings_count > existing_mappings, "Mapping document was not added: " + str(mapping_path) existing_mappings = mappings_count @@ -79,11 +77,8 @@ def test_add_example_dataset_types(clirunner, initialised_postgres_db, default_m assert result.exit_code == 0 -def test_error_returned_on_invalid(clirunner, initialised_postgres_db): - """ - :type initialised_postgres_db: datacube.drivers.postgres._connections.PostgresDb - """ - assert _dataset_type_count(initialised_postgres_db) == 0 +def test_error_returned_on_invalid(clirunner, index): + assert _dataset_type_count(index) == 0 for mapping_path in INVALID_MAPPING_DOCS: result = clirunner( @@ -95,10 +90,10 @@ def test_error_returned_on_invalid(clirunner, initialised_postgres_db): expect_success=False ) assert result.exit_code != 0, "Success return code for invalid document." 
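Taken together, the test changes in this module illustrate the three tiers of database access that remain after this changeset. A comparative sketch, assuming the fixtures used above (`index`, with its `_db` attribute) plus a resolved dataset `ds`:

    from datacube.drivers.postgres._core import has_schema

    # 1. Public, driver-independent API: what user code should call.
    with index.transaction() as trans:
        index.datasets.add(ds)                  # committed on clean exit, rolled back on error

    # 2. Index-internal: resources and test helpers borrow a connection that
    #    automatically joins any transaction active on this thread.
    with index._active_connection() as conn:
        products = list(conn.get_all_products())

    # 3. Driver-private pool access, for code that must bypass the index layer
    #    (schema bootstrap and teardown, as in test_db_init above).
    with index._db._connect() as conn:
        assert has_schema(index._db._engine, conn._connection)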
- assert _dataset_type_count(initialised_postgres_db) == 0, "Invalid document was added to DB" + assert _dataset_type_count(index) == 0, "Invalid document was added to DB" -def test_config_check(clirunner, initialised_postgres_db, local_config): +def test_config_check(clirunner, index, local_config): """ :type local_config: datacube.config.LocalConfig """ @@ -119,7 +114,7 @@ def test_config_check(clirunner, initialised_postgres_db, local_config): assert user_regex.match(result.output) -def test_list_users_does_not_fail(clirunner, local_config, initialised_postgres_db): +def test_list_users_does_not_fail(clirunner, local_config, index): """ :type local_config: datacube.config.LocalConfig """ @@ -178,30 +173,30 @@ def test_db_init_rebuild(clirunner, local_config, ls5_telem_type): ) in result.output -def test_db_init(clirunner, initialised_postgres_db): - if initialised_postgres_db.driver_name == "postgis": +def test_db_init(clirunner, index): + if index._db.driver_name == "postgis": from datacube.drivers.postgis._core import drop_db, has_schema else: from datacube.drivers.postgres._core import drop_db, has_schema - with initialised_postgres_db.connect() as connection: + with index._db._connect() as connection: drop_db(connection._connection) - assert not has_schema(initialised_postgres_db._engine, connection._connection) + assert not has_schema(index._db._engine, connection._connection) # Run on an empty database. - if initialised_postgres_db.driver_name == "postgis": + if index._db.driver_name == "postgis": result = clirunner(['-E', 'experimental', 'system', 'init']) else: result = clirunner(['system', 'init']) assert 'Created.' in result.output - with initialised_postgres_db.connect() as connection: - assert has_schema(initialised_postgres_db._engine, connection._connection) + with index._db._connect() as connection: + assert has_schema(index._db._engine, connection._connection) -def test_add_no_such_product(clirunner, initialised_postgres_db): +def test_add_no_such_product(clirunner, index): result = clirunner(['dataset', 'add', '--dtype', 'no_such_product', '/tmp'], expect_success=False) assert result.exit_code != 0 assert "DEPRECATED option detected" in result.output @@ -214,13 +209,13 @@ def test_add_no_such_product(clirunner, initialised_postgres_db): # Test that names are escaped ('test_user_"invalid+_chars_{n}', None), ('test_user_invalid_desc_{n}', 'Invalid "\' chars in description')]) -def example_user(clirunner, initialised_postgres_db, request): +def example_user(clirunner, index, request): username, description = request.param username = username.format(n=random.randint(111111, 999999)) # test_roles = (user_name for role_name, user_name, desc in roles if user_name.startswith('test_')) - with initialised_postgres_db.connect() as connection: + with index._db._connect() as connection: users = (user_name for role_name, user_name, desc in connection.list_users()) if username in users: connection.drop_users([username]) @@ -230,7 +225,7 @@ def example_user(clirunner, initialised_postgres_db, request): yield username, description - with initialised_postgres_db.connect() as connection: + with index._db._connect() as connection: users = (user_name for role_name, user_name, desc in connection.list_users()) if username in users: connection.drop_users([username]) diff --git a/tests/index/test_api_index_dataset.py b/tests/index/test_api_index_dataset.py index 9951e69a16..93a57a84ee 100644 --- a/tests/index/test_api_index_dataset.py +++ b/tests/index/test_api_index_dataset.py @@ -155,23 
+155,23 @@ def _build_dataset(doc): 'added', 'added_by', 'archived']) -class MockIndex(object): - def __init__(self, db): - self._db = db - - class MockDb(object): def __init__(self): self.dataset = {} self.dataset_source = set() @contextmanager - def begin(self): + def _connect(self): yield self - @contextmanager - def connect(self): - yield self + def begin(self): + pass + + def commit(self): + pass + + def rollback(self): + pass def get_dataset(self, id): return self.dataset.get(id, None) @@ -198,7 +198,7 @@ def insert_dataset_source(self, classifier, dataset_id, source_dataset_id): self.dataset_source.add((classifier, dataset_id, source_dataset_id)) -class MockTypesResource(object): +class MockTypesResource: def __init__(self, type_): self.type = type_ @@ -208,11 +208,28 @@ def get(self, *args, **kwargs): def get_by_name(self, *args, **kwargs): return self.type + @contextmanager + def _db_connection(self, transaction=False): + yield MockDb() + + +class MockIndex: + def __init__(self, db, product): + self._db = db + self.products = MockTypesResource(product) + + def thread_transaction(self): + return None + + @contextmanager + def _active_connection(self, transaction=False): + yield self._db + def test_index_dataset(): mock_db = MockDb() - mock_types = MockTypesResource(_EXAMPLE_DATASET_TYPE) - datasets = DatasetResource(mock_db, mock_types) + mock_index = MockIndex(mock_db, _EXAMPLE_DATASET_TYPE) + datasets = DatasetResource(mock_db, mock_index) dataset = datasets.add(_EXAMPLE_NBAR_DATASET) ids = {d.id for d in mock_db.dataset.values()} @@ -237,8 +254,8 @@ def test_index_dataset(): def test_index_already_ingested_source_dataset(): mock_db = MockDb() - mock_types = MockTypesResource(_EXAMPLE_DATASET_TYPE) - datasets = DatasetResource(mock_db, mock_types) + mock_index = MockIndex(mock_db, _EXAMPLE_DATASET_TYPE) + datasets = DatasetResource(mock_db, mock_index) dataset = datasets.add(_EXAMPLE_NBAR_DATASET.sources['ortho']) assert len(mock_db.dataset) == 2 @@ -251,8 +268,8 @@ def test_index_already_ingested_source_dataset(): def test_index_two_levels_already_ingested(): mock_db = MockDb() - mock_types = MockTypesResource(_EXAMPLE_DATASET_TYPE) - datasets = DatasetResource(mock_db, mock_types) + mock_index = MockIndex(mock_db, _EXAMPLE_DATASET_TYPE) + datasets = DatasetResource(mock_db, mock_index) dataset = datasets.add(_EXAMPLE_NBAR_DATASET.sources['ortho'].sources['satellite_telemetry_data']) assert len(mock_db.dataset) == 1
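In summary, the public surface added by this changeset (and emulated by the mock classes above for unit testing) is the index-driver-independent transaction API of EP07. A usage sketch, assuming an open `Datacube` handle `dc`, two resolved datasets `ds1` and `ds2`, and a hypothetical `looks_wrong()` validation check:

    # Commit on clean exit; any exception raised inside the block rolls the
    # transaction back and propagates to the caller.
    with dc.index.transaction() as trans:
        dc.index.datasets.add(ds1)
        dc.index.datasets.add(ds2)

    # Deliberate rollback without treating it as an error: rollback_exception()
    # returns a special exception that the context manager rolls back on and
    # then swallows, as exercised in test_transactions_api_ctx_mgr above.
    with dc.index.transaction() as trans:
        dc.index.datasets.add(ds1)
        if looks_wrong(ds1):                      # hypothetical validation
            raise trans.rollback_exception("discarding this batch")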