From 2a5352e90b221914a1d59ee606f58e6c96c24338 Mon Sep 17 00:00:00 2001 From: Paul Haesler Date: Fri, 9 Sep 2022 11:33:45 +1000 Subject: [PATCH 01/18] Resolve rebase conflicts --- datacube/index/abstract.py | 193 +++++++++++++++++++++++++++++-- datacube/index/exceptions.py | 6 + datacube/index/memory/index.py | 14 ++- datacube/index/null/index.py | 5 +- datacube/index/postgis/index.py | 7 +- datacube/index/postgres/index.py | 8 +- 6 files changed, 220 insertions(+), 13 deletions(-) diff --git a/datacube/index/abstract.py b/datacube/index/abstract.py index 52636465ea..799cb65b33 100644 --- a/datacube/index/abstract.py +++ b/datacube/index/abstract.py @@ -5,6 +5,7 @@ import datetime import logging from pathlib import Path +from threading import Lock from abc import ABC, abstractmethod from typing import (Any, Iterable, Iterator, @@ -13,11 +14,13 @@ from uuid import UUID from datacube.config import LocalConfig +from datacube.index.exceptions import TransactionException from datacube.index.fields import Field from datacube.model import Dataset, MetadataType, Range from datacube.model import DatasetType as Product from datacube.utils import cached_property, read_documents, InvalidDocException from datacube.utils.changes import AllowPolicy, Change, Offset +from datacube.utils.generic import thread_local_cache from datacube.utils.geometry import CRS, Geometry, box _LOG = logging.getLogger(__name__) @@ -988,6 +991,162 @@ def _extract_geom_from_query(self, q: Mapping[str, QueryField]) -> Optional[Geom return geom +class AbstractTransaction(ABC): + """ + Abstract base class for a Transaction Manager. All index implementations should extend this base class. + + Thread-local storage and locks ensures one active transaction per index per thread. + """ + + def __init__(self, index_id: str): + self._connection: Any = None + self.tls_id = f"txn-{index_id}" + self.obj_lock = Lock() + + # Main Transaction API + def begin(self) -> None: + """ + Start a new transaction. + + Raises an error if a transaction is already active for this thread. + + Calls implementation-specific _new_connection() method and manages thread local storage and locks. + """ + with self.obj_lock: + if self._connection is not None: + raise ValueError("Cannot start a new transaction as one is already active") + self._tls_stash() + + def commit(self) -> None: + """ + Commit the transaction. + + Raises an error if transaction is not active. + + Calls implementation-specific _commit() method, and manages thread local storage and locks. + """ + with self.lock: + if self._connection is not None: + raise ValueError("Cannot commit inactive transaction") + self._commit() + self._release_connection() + self._connection = None + self._tls_purge() + + def rollback(self) -> None: + """ + Rollback the transaction. + + Raises an error if transaction is not active. + + Calls implementation-specific _rollback() method, and manages thread local storage and locks. + """ + with self.lock: + if self._connection is not None: + raise ValueError("Cannot rollback inactive transaction") + self._rollback() + self._release_connection() + self._connection = None + self._tls_purge() + + @property + def active(self): + """ + :return: True if the transaction is active. + """ + return self._connection is not None + + # Manage thread-local storage + def _tls_stash(self) -> None: + """ + Check TLS is empty, create a new connection and stash it. 
+ :return: + """ + stored_val = thread_local_cache(self.tls_id) + if stored_val is not None: + raise ValueError("Cannot start a new transaction as one is already active for this thread") + self._connection = self._new_connection() + thread_local_cache(self.tls_id, self, purge=True) + + def _tls_purge(self) -> None: + thread_local_cache(self.tls_id, purge=True) + + @classmethod + def thread_transaction(cls, index_id: str) -> "AbstractTransaction": + return thread_local_cache(f"txn-{index_id}", None) + + # Commit/Rollback exceptions for Context Manager usage patterns + def commit_exception(self, errmsg: str) -> TransactionException: + return TransactionException(errmsg, commit=True) + + def rollback_exception(self, errmsg: str) -> TransactionException: + return TransactionException(errmsg, commit=False) + + # Context Manager Interface + def __enter__(self): + self.begin() + + def __exit__(self, exc_type, exc_value, traceback): + if issubclass(exc_type, TransactionException): + # User raised a TransactionException, Commit or rollback as per exception + if exc_value.commit: + self.commit() + else: + self.rollback() + # Tell runtime exception is caught and handled. + return True + elif exc_value is not None: + # Any other exception - rollback + self.rollback() + # Instruct runtime to rethrow exception + return False + else: + # Exited without exception - commit and continue + self.commit() + return True + + # Internal abstract methods for implementation-specific functionality + @abstractmethod + def _new_connection(self) -> Any: + """ + :return: a new index driver object representing a database connection or equivalent against which transactions + will be executed. + """ + + @abstractmethod + def _commit(self) -> None: + """ + Commit the transaction. + """ + + @abstractmethod + def _rollback(self) -> None: + """ + Rollback the transaction. + """ + + @abstractmethod + def _release_connection(self) -> None: + """ + Release the connection object stored in self._connection + """ + + +class UnhandledTransaction(AbstractTransaction): + # Minimal implementation for index drivers with no transaction handling. + def _new_connection(self) -> Any: + return True + + def _commit(self) -> None: + pass + + def _rollback(self) -> None: + pass + + def _release_connection(self) -> None: + pass + + class AbstractIndex(ABC): """ Abstract base class for an Index. All Index implementations should @@ -1004,31 +1163,33 @@ class AbstractIndex(ABC): # supports lineage supports_lineage = True supports_source_filters = True + # Supports ACID transactions + supports_transactions = False @property @abstractmethod def url(self) -> str: - ... + """A string representing the index""" @property @abstractmethod def users(self) -> AbstractUserResource: - ... + """A User Resource instance for the index""" @property @abstractmethod def metadata_types(self) -> AbstractMetadataTypeResource: - ... + """A MetadataType Resource instance for the index""" @property @abstractmethod def products(self) -> AbstractProductResource: - ... + """A Product Resource instance for the index""" @property @abstractmethod def datasets(self) -> AbstractDatasetResource: - ... + """A Dataset Resource instance for the index""" @classmethod @abstractmethod @@ -1037,24 +1198,38 @@ def from_config(cls, application_name: Optional[str] = None, validate_connection: bool = True ) -> "AbstractIndex": - ... + """Instantiate a new index from a LocalConfig object""" @classmethod @abstractmethod def get_dataset_fields(cls, doc: dict ) -> Mapping[str, Field]: - ... 
+ """Return dataset search fields from a metadata type document""" @abstractmethod def init_db(self, with_default_types: bool = True, with_permissions: bool = True) -> bool: - ... + """ + Initialise an empty database. + + :param with_default_types: Whether to create default metadata types + :param with_permissions: Whether to create db permissions + :return: true if the database was created, false if already exists + """ @abstractmethod def close(self) -> None: - ... + """ + Close and cleanup the Index. + """ + + @abstractmethod + def transaction(self) -> AbstractTransaction: + """ + :return: a Transaction context manager for this index. + """ @abstractmethod def create_spatial_index(self, crs: CRS) -> bool: diff --git a/datacube/index/exceptions.py b/datacube/index/exceptions.py index 9e731c0b9f..282a6a59f0 100644 --- a/datacube/index/exceptions.py +++ b/datacube/index/exceptions.py @@ -14,3 +14,9 @@ class MissingRecordError(Exception): class IndexSetupError(Exception): pass + + +class TransactionException(Exception): + def __init__(self, *args, commit=False, **kwargs): + super().__init__(*args, **kwargs) + self.commit = commit diff --git a/datacube/index/memory/index.py b/datacube/index/memory/index.py index fd2045c04a..55016a83fe 100644 --- a/datacube/index/memory/index.py +++ b/datacube/index/memory/index.py @@ -3,19 +3,24 @@ # Copyright (c) 2015-2022 ODC Contributors # SPDX-License-Identifier: Apache-2.0 import logging +from threading import Lock from datacube.index.memory._datasets import DatasetResource # type: ignore from datacube.index.memory._fields import get_dataset_fields from datacube.index.memory._metadata_types import MetadataTypeResource from datacube.index.memory._products import ProductResource from datacube.index.memory._users import UserResource -from datacube.index.abstract import AbstractIndex, AbstractIndexDriver +from datacube.index.abstract import AbstractIndex, AbstractIndexDriver, UnhandledTransaction from datacube.model import MetadataType from datacube.utils.geometry import CRS _LOG = logging.getLogger(__name__) +counter = 0 +counter_lock = Lock() + + class Index(AbstractIndex): """ Lightweight in-memory index driver @@ -26,6 +31,10 @@ def __init__(self) -> None: self._metadata_types = MetadataTypeResource() self._products = ProductResource(self.metadata_types) self._datasets = DatasetResource(self.products) + global counter + with counter_lock: + counter = counter + 1 + self._index_id = f"memory={counter}" @property def users(self) -> UserResource: @@ -47,6 +56,9 @@ def datasets(self) -> DatasetResource: def url(self) -> str: return "memory" + def transaction(self) -> UnhandledTransaction: + return UnhandledTransaction(self._index_id) + @classmethod def from_config(cls, config, application_name=None, validate_connection=True): return cls() diff --git a/datacube/index/null/index.py b/datacube/index/null/index.py index 36a94b8783..4518933ea8 100644 --- a/datacube/index/null/index.py +++ b/datacube/index/null/index.py @@ -8,7 +8,7 @@ from datacube.index.null._metadata_types import MetadataTypeResource from datacube.index.null._products import ProductResource from datacube.index.null._users import UserResource -from datacube.index.abstract import AbstractIndex, AbstractIndexDriver +from datacube.index.abstract import AbstractIndex, AbstractIndexDriver, UnhandledTransaction from datacube.model import MetadataType from datacube.model.fields import get_dataset_fields from datacube.utils.geometry import CRS @@ -49,6 +49,9 @@ def datasets(self) -> DatasetResource: def 
url(self) -> str: return "null" + def transaction(self) -> UnhandledTransaction: + return UnhandledTransaction("null") + @classmethod def from_config(cls, config, application_name=None, validate_connection=True): return cls() diff --git a/datacube/index/postgis/index.py b/datacube/index/postgis/index.py index 6b03f4bcf5..a2b3581835 100644 --- a/datacube/index/postgis/index.py +++ b/datacube/index/postgis/index.py @@ -10,7 +10,7 @@ from datacube.index.postgis._metadata_types import MetadataTypeResource from datacube.index.postgis._products import ProductResource from datacube.index.postgis._users import UserResource -from datacube.index.abstract import AbstractIndex, AbstractIndexDriver, default_metadata_type_docs +from datacube.index.abstract import AbstractIndex, AbstractIndexDriver, default_metadata_type_docs, AbstractTransaction from datacube.model import MetadataType from datacube.utils.geometry import CRS @@ -48,6 +48,7 @@ class Index(AbstractIndex): # Hopefully can reinstate a simpler form of lineage support, but leave for now supports_lineage = False supports_source_filters = False + supports_transactions = True def __init__(self, db: PostGisDb) -> None: # POSTGIS driver is not stable with respect to database schema or internal APIs. @@ -115,6 +116,10 @@ def close(self): """ self._db.close() + def transaction(self) -> AbstractTransaction: + # TODO + return None + def create_spatial_index(self, crs: CRS) -> bool: sp_idx = self._db.create_spatial_index(crs) return sp_idx is not None diff --git a/datacube/index/postgres/index.py b/datacube/index/postgres/index.py index 16e4daba10..bb2f9ece42 100644 --- a/datacube/index/postgres/index.py +++ b/datacube/index/postgres/index.py @@ -9,7 +9,7 @@ from datacube.index.postgres._metadata_types import MetadataTypeResource from datacube.index.postgres._products import ProductResource from datacube.index.postgres._users import UserResource -from datacube.index.abstract import AbstractIndex, AbstractIndexDriver, default_metadata_type_docs +from datacube.index.abstract import AbstractIndex, AbstractIndexDriver, default_metadata_type_docs, AbstractTransaction from datacube.model import MetadataType from datacube.utils.geometry import CRS @@ -40,6 +40,8 @@ class Index(AbstractIndex): :type metadata_types: datacube.index._metadata_types.MetadataTypeResource """ + supports_transactions = True + def __init__(self, db: PostgresDb) -> None: self._db = db @@ -99,6 +101,10 @@ def close(self): """ self._db.close() + def transaction(self) -> AbstractTransaction: + # TODO + return None + def create_spatial_index(self, crs: CRS) -> None: _LOG.warning("postgres driver does not support spatio-temporal indexes") From 39f11cadc78606b4aadb23c7c862f54321472389 Mon Sep 17 00:00:00 2001 From: Paul Haesler Date: Fri, 9 Sep 2022 12:01:14 +1000 Subject: [PATCH 02/18] Move thread-local transaction finder from AbstractTransaction classmethod to AbstractIndex method. 
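
With this change, code that needs the active transaction reaches it through the
index object itself rather than through the transaction class. A rough sketch of
the intended call-site change (illustrative only, not part of this patch):

    # before: look the transaction up via the class, keyed by an index id string
    txn = AbstractTransaction.thread_transaction(index_id)

    # after: ask the index directly
    txn = index.thread_transaction()   # returns None when no transaction is active
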
--- datacube/index/abstract.py | 18 ++++++++++++++---- datacube/index/memory/index.py | 6 +++++- datacube/index/null/index.py | 6 +++++- datacube/index/postgis/index.py | 4 ++++ datacube/index/postgres/index.py | 4 ++++ 5 files changed, 32 insertions(+), 6 deletions(-) diff --git a/datacube/index/abstract.py b/datacube/index/abstract.py index 799cb65b33..fe9c1b2c15 100644 --- a/datacube/index/abstract.py +++ b/datacube/index/abstract.py @@ -1071,10 +1071,6 @@ def _tls_stash(self) -> None: def _tls_purge(self) -> None: thread_local_cache(self.tls_id, purge=True) - @classmethod - def thread_transaction(cls, index_id: str) -> "AbstractTransaction": - return thread_local_cache(f"txn-{index_id}", None) - # Commit/Rollback exceptions for Context Manager usage patterns def commit_exception(self, errmsg: str) -> TransactionException: return TransactionException(errmsg, commit=True) @@ -1225,6 +1221,14 @@ def close(self) -> None: Close and cleanup the Index. """ + @property + @abstractmethod + def index_id(self) -> str: + """ + :return: Unique ID for this index + (e.g. same database/dataset storage + same index driver implementation = same id) + """ + @abstractmethod def transaction(self) -> AbstractTransaction: """ @@ -1241,6 +1245,12 @@ def create_spatial_index(self, crs: CRS) -> bool: None if spatial indexes are not supported. """ + def thread_transaction(self) -> Optional["AbstractTransaction"]: + """ + :return: The existing Transaction object cached in thread-local storage for this index, if there is one. + """ + return thread_local_cache(f"txn-{self.index_id}", None) + def spatial_indexes(self, refresh=False) -> Iterable[CRS]: """ Return a list of CRSs for which spatiotemporal indexes exist in the database. diff --git a/datacube/index/memory/index.py b/datacube/index/memory/index.py index 55016a83fe..1a1385ffe3 100644 --- a/datacube/index/memory/index.py +++ b/datacube/index/memory/index.py @@ -56,8 +56,12 @@ def datasets(self) -> DatasetResource: def url(self) -> str: return "memory" + @property + def index_id(self) -> str: + return self._index_id + def transaction(self) -> UnhandledTransaction: - return UnhandledTransaction(self._index_id) + return UnhandledTransaction(self.index_id) @classmethod def from_config(cls, config, application_name=None, validate_connection=True): diff --git a/datacube/index/null/index.py b/datacube/index/null/index.py index 4518933ea8..a678f7998a 100644 --- a/datacube/index/null/index.py +++ b/datacube/index/null/index.py @@ -49,8 +49,12 @@ def datasets(self) -> DatasetResource: def url(self) -> str: return "null" + @property + def index_id(self) -> str: + return "null" + def transaction(self) -> UnhandledTransaction: - return UnhandledTransaction("null") + return UnhandledTransaction(self.index_id) @classmethod def from_config(cls, config, application_name=None, validate_connection=True): diff --git a/datacube/index/postgis/index.py b/datacube/index/postgis/index.py index a2b3581835..d27f9e94f4 100644 --- a/datacube/index/postgis/index.py +++ b/datacube/index/postgis/index.py @@ -116,6 +116,10 @@ def close(self): """ self._db.close() + @property + def index_id(self) -> str: + return self.url + def transaction(self) -> AbstractTransaction: # TODO return None diff --git a/datacube/index/postgres/index.py b/datacube/index/postgres/index.py index bb2f9ece42..8ef0f7efae 100644 --- a/datacube/index/postgres/index.py +++ b/datacube/index/postgres/index.py @@ -101,6 +101,10 @@ def close(self): """ self._db.close() + @property + def index_id(self) -> str: + return 
f"legacy_{self.url}" + def transaction(self) -> AbstractTransaction: # TODO return None From 051a1dc0fe82aa66679e38657478e2ee5bb77b76 Mon Sep 17 00:00:00 2001 From: Paul Haesler Date: Fri, 9 Sep 2022 12:21:50 +1000 Subject: [PATCH 03/18] Test trivial transaction implementations (memory and null). --- datacube/index/abstract.py | 11 ++++++----- integration_tests/index/test_memory_index.py | 16 ++++++++++++++++ 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/datacube/index/abstract.py b/datacube/index/abstract.py index fe9c1b2c15..3e9442d80d 100644 --- a/datacube/index/abstract.py +++ b/datacube/index/abstract.py @@ -1025,8 +1025,8 @@ def commit(self) -> None: Calls implementation-specific _commit() method, and manages thread local storage and locks. """ - with self.lock: - if self._connection is not None: + with self.obj_lock: + if self._connection is None: raise ValueError("Cannot commit inactive transaction") self._commit() self._release_connection() @@ -1041,8 +1041,8 @@ def rollback(self) -> None: Calls implementation-specific _rollback() method, and manages thread local storage and locks. """ - with self.lock: - if self._connection is not None: + with self.obj_lock: + if self._connection is None: raise ValueError("Cannot rollback inactive transaction") self._rollback() self._release_connection() @@ -1066,7 +1066,8 @@ def _tls_stash(self) -> None: if stored_val is not None: raise ValueError("Cannot start a new transaction as one is already active for this thread") self._connection = self._new_connection() - thread_local_cache(self.tls_id, self, purge=True) + thread_local_cache(self.tls_id, purge=True) + thread_local_cache(self.tls_id, self) def _tls_purge(self) -> None: thread_local_cache(self.tls_id, purge=True) diff --git a/integration_tests/index/test_memory_index.py b/integration_tests/index/test_memory_index.py index 97b276c140..a25a421e4a 100644 --- a/integration_tests/index/test_memory_index.py +++ b/integration_tests/index/test_memory_index.py @@ -552,3 +552,19 @@ def test_memory_dataset_add(dataset_add_configs, mem_index_fresh): ds_from_idx = idx.datasets.get(ds_.id, include_sources=True) assert ds_from_idx.sources['ab'].id == ds_.sources['ab'].id assert ds_from_idx.sources['ac'].sources["cd"].id == ds_.sources['ac'].sources['cd'].id + + +def test_mem_transactions(mem_index_fresh): + trans = mem_index_fresh.index.transaction() + assert not trans.active + trans.begin() + assert trans.active + trans.commit() + assert not trans.active + trans.begin() + assert mem_index_fresh.index.thread_transaction() == trans + with pytest.raises(ValueError): + trans.begin() + trans.rollback() + assert not trans.active + assert mem_index_fresh.index.thread_transaction() is None From c8873a37241bfdab20f69024f08c1b57b7c58840 Mon Sep 17 00:00:00 2001 From: Paul Haesler Date: Fri, 9 Sep 2022 14:15:34 +1000 Subject: [PATCH 04/18] Test trivial transaction implementations (memory and null). 
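
Adds the null-driver counterpart of the memory-driver transaction test. A minimal
sketch of the behaviour under test (names as used in the test below):

    trans = dc.index.transaction()
    trans.begin()                                  # stashes itself in thread-local storage
    assert dc.index.thread_transaction() == trans
    trans.rollback()                               # or trans.commit()
    assert dc.index.thread_transaction() is None
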
--- integration_tests/index/test_null_index.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/integration_tests/index/test_null_index.py b/integration_tests/index/test_null_index.py index 269018cbdb..8c235f0058 100644 --- a/integration_tests/index/test_null_index.py +++ b/integration_tests/index/test_null_index.py @@ -120,3 +120,20 @@ def test_null_dataset_resource(null_config): assert dc.index.datasets.search_summaries(foo="bar", baz=12) == [] assert dc.index.datasets.search_eager(foo="bar", baz=12) == [] assert dc.index.datasets.search_returning_datasets_light(("foo", "baz"), foo="bar", baz=12) == [] + + +def test_null_transactions(null_config): + with Datacube(config=null_config, validate_connection=True) as dc: + trans = dc.index.transaction() + assert not trans.active + trans.begin() + assert trans.active + trans.commit() + assert not trans.active + trans.begin() + assert dc.index.thread_transaction() == trans + with pytest.raises(ValueError): + trans.begin() + trans.rollback() + assert not trans.active + assert dc.index.thread_transaction() is None From 43bdd4390c63c36ffabfa9c277527f24d3d9e2a1 Mon Sep 17 00:00:00 2001 From: Paul Haesler Date: Fri, 9 Sep 2022 15:11:48 +1000 Subject: [PATCH 05/18] Implement transaction API for postgres and postgis drivers. (Not honoured by API methods yet.) --- datacube/drivers/postgis/_api.py | 3 +++ datacube/drivers/postgres/_api.py | 3 +++ datacube/index/postgis/index.py | 23 ++++++++++++++++++++--- datacube/index/postgres/index.py | 22 ++++++++++++++++++++-- 4 files changed, 46 insertions(+), 5 deletions(-) diff --git a/datacube/drivers/postgis/_api.py b/datacube/drivers/postgis/_api.py index afc2b26edf..49f758b95c 100644 --- a/datacube/drivers/postgis/_api.py +++ b/datacube/drivers/postgis/_api.py @@ -182,6 +182,9 @@ def __init__(self, parentdb, connection): def in_transaction(self): return self._connection.in_transaction() + def commit(self): + self._connection.execute(text('COMMIT')) + def rollback(self): self._connection.execute(text('ROLLBACK')) diff --git a/datacube/drivers/postgres/_api.py b/datacube/drivers/postgres/_api.py index 178b39faa9..36c3c54408 100644 --- a/datacube/drivers/postgres/_api.py +++ b/datacube/drivers/postgres/_api.py @@ -185,6 +185,9 @@ def in_transaction(self): def rollback(self): self._connection.execute(text('ROLLBACK')) + def commit(self): + self._connection.execute(text('COMMIT')) + def execute(self, command): return self._connection.execute(command) diff --git a/datacube/index/postgis/index.py b/datacube/index/postgis/index.py index d27f9e94f4..9a016f547f 100644 --- a/datacube/index/postgis/index.py +++ b/datacube/index/postgis/index.py @@ -3,7 +3,7 @@ # Copyright (c) 2015-2020 ODC Contributors # SPDX-License-Identifier: Apache-2.0 import logging -from typing import Iterable, Sequence +from typing import Any, Iterable, Sequence from datacube.drivers.postgis import PostGisDb from datacube.index.postgis._datasets import DatasetResource, DSID # type: ignore @@ -17,6 +17,24 @@ _LOG = logging.getLogger(__name__) +class PostgisTransaction(AbstractTransaction): + def __init__(self, db: PostGisDb, idx_id: str) -> None: + super().__init__(idx_id) + self._db = db + + def _new_connection(self) -> Any: + return self._db.begin() + + def _commit(self) -> None: + self._connection.commit() + + def _rollback(self) -> None: + self._connection.rollback() + + def _release_connection(self) -> None: + self._connection.close() + + class Index(AbstractIndex): """ Access to the datacube index. 
@@ -121,8 +139,7 @@ def index_id(self) -> str: return self.url def transaction(self) -> AbstractTransaction: - # TODO - return None + return PostgisTransaction(self._db, self.index_id) def create_spatial_index(self, crs: CRS) -> bool: sp_idx = self._db.create_spatial_index(crs) diff --git a/datacube/index/postgres/index.py b/datacube/index/postgres/index.py index 8ef0f7efae..8f09f4b2de 100644 --- a/datacube/index/postgres/index.py +++ b/datacube/index/postgres/index.py @@ -3,6 +3,7 @@ # Copyright (c) 2015-2020 ODC Contributors # SPDX-License-Identifier: Apache-2.0 import logging +from typing import Any from datacube.drivers.postgres import PostgresDb from datacube.index.postgres._datasets import DatasetResource # type: ignore @@ -16,6 +17,24 @@ _LOG = logging.getLogger(__name__) +class PostgresTransaction(AbstractTransaction): + def __init__(self, db: PostgresDb, idx_id: str) -> None: + super().__init__(idx_id) + self._db = db + + def _new_connection(self) -> Any: + return self._db.begin() + + def _commit(self) -> None: + self._connection.commit() + + def _rollback(self) -> None: + self._connection.rollback() + + def _release_connection(self) -> None: + self._connection.close() + + class Index(AbstractIndex): """ Access to the datacube index. @@ -106,8 +125,7 @@ def index_id(self) -> str: return f"legacy_{self.url}" def transaction(self) -> AbstractTransaction: - # TODO - return None + return PostgresTransaction(self._db, self.index_id) def create_spatial_index(self, crs: CRS) -> None: _LOG.warning("postgres driver does not support spatio-temporal indexes") From 7e4f65fa430e8e7cc421564c50e26d02bea01f5f Mon Sep 17 00:00:00 2001 From: Paul Haesler Date: Mon, 12 Sep 2022 12:51:35 +1000 Subject: [PATCH 06/18] Give postgis and postgres index resources links to the index object. --- datacube/index/postgis/_datasets.py | 7 ++++--- datacube/index/postgis/_metadata_types.py | 3 ++- datacube/index/postgis/_products.py | 5 +++-- datacube/index/postgis/_users.py | 3 ++- datacube/index/postgis/index.py | 8 ++++---- datacube/index/postgres/_datasets.py | 5 +++-- datacube/index/postgres/_metadata_types.py | 3 ++- datacube/index/postgres/_products.py | 5 +++-- datacube/index/postgres/_users.py | 3 ++- datacube/index/postgres/index.py | 8 ++++---- tests/index/test_api_index_dataset.py | 20 +++++++++++++------- 11 files changed, 42 insertions(+), 28 deletions(-) diff --git a/datacube/index/postgis/_datasets.py b/datacube/index/postgis/_datasets.py index 7e43e7f210..5903c5f242 100755 --- a/datacube/index/postgis/_datasets.py +++ b/datacube/index/postgis/_datasets.py @@ -38,14 +38,15 @@ class DatasetResource(AbstractDatasetResource): :type types: datacube.index._products.ProductResource """ - def __init__(self, db, product_resource): + def __init__(self, db, index): """ :type db: datacube.drivers.postgis._connections.PostgresDb :type product_resource: datacube.index._products.ProductResource """ self._db = db - self.types = product_resource - self.products = product_resource + self._index = index + self.types = self._index.products # types is a compatibility alias for products. 
+ self.products = self._index.products def get(self, id_: Union[str, UUID], include_sources=False): """ diff --git a/datacube/index/postgis/_metadata_types.py b/datacube/index/postgis/_metadata_types.py index 94556b07d6..809f95c3b8 100644 --- a/datacube/index/postgis/_metadata_types.py +++ b/datacube/index/postgis/_metadata_types.py @@ -15,11 +15,12 @@ class MetadataTypeResource(AbstractMetadataTypeResource): - def __init__(self, db): + def __init__(self, db, index): """ :type db: datacube.drivers.postgis._connections.PostgresDb """ self._db = db + self._index = index self.get_unsafe = lru_cache()(self.get_unsafe) self.get_by_name_unsafe = lru_cache()(self.get_by_name_unsafe) diff --git a/datacube/index/postgis/_products.py b/datacube/index/postgis/_products.py index 2f380258b7..3ad45b5f3c 100644 --- a/datacube/index/postgis/_products.py +++ b/datacube/index/postgis/_products.py @@ -23,13 +23,14 @@ class ProductResource(AbstractProductResource): :type metadata_type_resource: datacube.index._metadata_types.MetadataTypeResource """ - def __init__(self, db, metadata_type_resource): + def __init__(self, db, index): """ :type db: datacube.drivers.postgis._connections.PostgresDb :type metadata_type_resource: datacube.index._metadata_types.MetadataTypeResource """ self._db = db - self.metadata_type_resource = metadata_type_resource + self._index = index + self.metadata_type_resource = self._index.metadata_types self.get_unsafe = lru_cache()(self.get_unsafe) self.get_by_name_unsafe = lru_cache()(self.get_by_name_unsafe) diff --git a/datacube/index/postgis/_users.py b/datacube/index/postgis/_users.py index 85a2ac9e34..ee7e7b4a3d 100644 --- a/datacube/index/postgis/_users.py +++ b/datacube/index/postgis/_users.py @@ -8,11 +8,12 @@ class UserResource(AbstractUserResource): - def __init__(self, db: PostGisDb) -> None: + def __init__(self, db: PostGisDb, index: "datacube.index.postgis.index.Index") -> None: """ :type db: datacube.drivers.postgis.PostGisDb """ self._db = db + self._index = index def grant_role(self, role: str, *usernames: str) -> None: """ diff --git a/datacube/index/postgis/index.py b/datacube/index/postgis/index.py index 9a016f547f..95fd7d5053 100644 --- a/datacube/index/postgis/index.py +++ b/datacube/index/postgis/index.py @@ -75,10 +75,10 @@ def __init__(self, db: PostGisDb) -> None: WARNING: Database schema and internal APIs may change significantly between releases. 
Use at your own risk.""") self._db = db - self._users = UserResource(db) - self._metadata_types = MetadataTypeResource(db) - self._products = ProductResource(db, self.metadata_types) - self._datasets = DatasetResource(db, self.products) + self._users = UserResource(db, self) + self._metadata_types = MetadataTypeResource(db, self) + self._products = ProductResource(db, self) + self._datasets = DatasetResource(db, self) @property def users(self) -> UserResource: diff --git a/datacube/index/postgres/_datasets.py b/datacube/index/postgres/_datasets.py index 127a241dc3..273fe040d8 100755 --- a/datacube/index/postgres/_datasets.py +++ b/datacube/index/postgres/_datasets.py @@ -35,13 +35,14 @@ class DatasetResource(AbstractDatasetResource): :type types: datacube.index._products.ProductResource """ - def __init__(self, db, dataset_type_resource): + def __init__(self, db, index): """ :type db: datacube.drivers.postgres._connections.PostgresDb :type dataset_type_resource: datacube.index._products.ProductResource """ self._db = db - self.types = dataset_type_resource + self._index = index + self.types = self._index.products def get(self, id_: Union[str, UUID], include_sources=False): """ diff --git a/datacube/index/postgres/_metadata_types.py b/datacube/index/postgres/_metadata_types.py index dbf4045b51..ced04b0026 100644 --- a/datacube/index/postgres/_metadata_types.py +++ b/datacube/index/postgres/_metadata_types.py @@ -15,11 +15,12 @@ class MetadataTypeResource(AbstractMetadataTypeResource): - def __init__(self, db): + def __init__(self, db, index): """ :type db: datacube.drivers.postgres._connections.PostgresDb """ self._db = db + self._index = index self.get_unsafe = lru_cache()(self.get_unsafe) self.get_by_name_unsafe = lru_cache()(self.get_by_name_unsafe) diff --git a/datacube/index/postgres/_products.py b/datacube/index/postgres/_products.py index 4bdd34635c..2af0745277 100644 --- a/datacube/index/postgres/_products.py +++ b/datacube/index/postgres/_products.py @@ -23,13 +23,14 @@ class ProductResource(AbstractProductResource): :type metadata_type_resource: datacube.index._metadata_types.MetadataTypeResource """ - def __init__(self, db, metadata_type_resource): + def __init__(self, db, index): """ :type db: datacube.drivers.postgres._connections.PostgresDb :type metadata_type_resource: datacube.index._metadata_types.MetadataTypeResource """ self._db = db - self.metadata_type_resource = metadata_type_resource + self._index = index + self.metadata_type_resource = self._index.metadata_types self.get_unsafe = lru_cache()(self.get_unsafe) self.get_by_name_unsafe = lru_cache()(self.get_by_name_unsafe) diff --git a/datacube/index/postgres/_users.py b/datacube/index/postgres/_users.py index f9cfd14152..1d3c1362ce 100644 --- a/datacube/index/postgres/_users.py +++ b/datacube/index/postgres/_users.py @@ -8,11 +8,12 @@ class UserResource(AbstractUserResource): - def __init__(self, db: PostgresDb) -> None: + def __init__(self, db: PostgresDb, index: "datacube.index.postgres.index.Index") -> None: """ :type db: datacube.drivers.postgres._connections.PostgresDb """ self._db = db + self._index = index def grant_role(self, role: str, *usernames: str) -> None: """ diff --git a/datacube/index/postgres/index.py b/datacube/index/postgres/index.py index 8f09f4b2de..d5f67715d0 100644 --- a/datacube/index/postgres/index.py +++ b/datacube/index/postgres/index.py @@ -64,10 +64,10 @@ class Index(AbstractIndex): def __init__(self, db: PostgresDb) -> None: self._db = db - self._users = UserResource(db) - 
self._metadata_types = MetadataTypeResource(db) - self._products = ProductResource(db, self.metadata_types) - self._datasets = DatasetResource(db, self.products) + self._users = UserResource(db, self) + self._metadata_types = MetadataTypeResource(db, self) + self._products = ProductResource(db, self) + self._datasets = DatasetResource(db, self) @property def users(self) -> UserResource: diff --git a/tests/index/test_api_index_dataset.py b/tests/index/test_api_index_dataset.py index 9951e69a16..786b0aa500 100644 --- a/tests/index/test_api_index_dataset.py +++ b/tests/index/test_api_index_dataset.py @@ -198,7 +198,7 @@ def insert_dataset_source(self, classifier, dataset_id, source_dataset_id): self.dataset_source.add((classifier, dataset_id, source_dataset_id)) -class MockTypesResource(object): +class MockTypesResource: def __init__(self, type_): self.type = type_ @@ -209,10 +209,16 @@ def get_by_name(self, *args, **kwargs): return self.type +class MockIndex: + def __init__(self, db, product): + self._db = db + self.products = MockTypesResource(product) + + def test_index_dataset(): mock_db = MockDb() - mock_types = MockTypesResource(_EXAMPLE_DATASET_TYPE) - datasets = DatasetResource(mock_db, mock_types) + mock_index = MockIndex(mock_db, _EXAMPLE_DATASET_TYPE) + datasets = DatasetResource(mock_db, mock_index) dataset = datasets.add(_EXAMPLE_NBAR_DATASET) ids = {d.id for d in mock_db.dataset.values()} @@ -237,8 +243,8 @@ def test_index_dataset(): def test_index_already_ingested_source_dataset(): mock_db = MockDb() - mock_types = MockTypesResource(_EXAMPLE_DATASET_TYPE) - datasets = DatasetResource(mock_db, mock_types) + mock_index = MockIndex(mock_db, _EXAMPLE_DATASET_TYPE) + datasets = DatasetResource(mock_db, mock_index) dataset = datasets.add(_EXAMPLE_NBAR_DATASET.sources['ortho']) assert len(mock_db.dataset) == 2 @@ -251,8 +257,8 @@ def test_index_already_ingested_source_dataset(): def test_index_two_levels_already_ingested(): mock_db = MockDb() - mock_types = MockTypesResource(_EXAMPLE_DATASET_TYPE) - datasets = DatasetResource(mock_db, mock_types) + mock_index = MockIndex(mock_db, _EXAMPLE_DATASET_TYPE) + datasets = DatasetResource(mock_db, mock_index) dataset = datasets.add(_EXAMPLE_NBAR_DATASET.sources['ortho'].sources['satellite_telemetry_data']) assert len(mock_db.dataset) == 1 From d88d4cbd734835b843ebc62d1d06ee6846e4a673 Mon Sep 17 00:00:00 2001 From: Paul Haesler Date: Mon, 12 Sep 2022 16:57:34 +1000 Subject: [PATCH 07/18] Postgres index driver now honouring transaction API. No explicit testing yet. --- datacube/index/postgres/_datasets.py | 59 +++++++++++----------- datacube/index/postgres/_metadata_types.py | 18 ++++--- datacube/index/postgres/_products.py | 17 ++++--- datacube/index/postgres/_users.py | 13 ++--- tests/index/test_api_index_dataset.py | 3 ++ 5 files changed, 60 insertions(+), 50 deletions(-) diff --git a/datacube/index/postgres/_datasets.py b/datacube/index/postgres/_datasets.py index 273fe040d8..53807bfa4d 100755 --- a/datacube/index/postgres/_datasets.py +++ b/datacube/index/postgres/_datasets.py @@ -1,6 +1,6 @@ # This file is part of the Open Data Cube, see https://opendatacube.org for more information # -# Copyright (c) 2015-2020 ODC Contributors +# Copyright (c) 2015-2022 ODC Contributors # SPDX-License-Identifier: Apache-2.0 """ API for dataset indexing, access and search. 
@@ -17,6 +17,7 @@ from datacube.drivers.postgres._fields import SimpleDocField, DateDocField from datacube.drivers.postgres._schema import DATASET from datacube.index.abstract import AbstractDatasetResource, DatasetSpatialMixin, DSID +from datacube.index.postgres._transaction import IndexResourceAddIn from datacube.model import Dataset, DatasetType from datacube.model.fields import Field from datacube.model.utils import flatten_datasets @@ -29,7 +30,7 @@ # It's a public api, so we can't reorganise old methods. # pylint: disable=too-many-public-methods, too-many-lines -class DatasetResource(AbstractDatasetResource): +class DatasetResource(AbstractDatasetResource, IndexResourceAddIn): """ :type _db: datacube.drivers.postgres._connections.PostgresDb :type types: datacube.index._products.ProductResource @@ -55,7 +56,7 @@ def get(self, id_: Union[str, UUID], include_sources=False): if isinstance(id_, str): id_ = UUID(id_) - with self._db.connect() as connection: + with self.db_connection() as connection: if not include_sources: dataset = connection.get_dataset(id_) return self._make(dataset, full_info=True) if dataset else None @@ -84,7 +85,7 @@ def to_uuid(x): ids = [to_uuid(i) for i in ids] - with self._db.connect() as connection: + with self.db_connection() as connection: rows = connection.get_datasets(ids) return [self._make(r, full_info=True) for r in rows] @@ -97,7 +98,7 @@ def get_derived(self, id_): """ if not isinstance(id_, UUID): id_ = UUID(id_) - with self._db.connect() as connection: + with self.db_connection() as connection: return [ self._make(result, full_info=True) for result in connection.get_derived_datasets(id_) @@ -110,7 +111,7 @@ def has(self, id_): :param typing.Union[UUID, str] id_: dataset id :rtype: bool """ - with self._db.connect() as connection: + with self.db_connection() as connection: return connection.contains_dataset(id_) def bulk_has(self, ids_): @@ -123,7 +124,7 @@ def bulk_has(self, ids_): :rtype: [bool] """ - with self._db.connect() as connection: + with self.db_connection() as connection: existing = set(connection.datasets_intersection(ids_)) return [x in existing for x in @@ -183,7 +184,7 @@ def process_bunch(dss, main_ds, transaction): dss = [dataset] - with self._db.begin() as transaction: + with self.db_connection(transaction=True) as transaction: process_bunch(dss, dataset, transaction) return dataset @@ -208,7 +209,7 @@ def load_field(f: Union[str, fields.Field]) -> fields.Field: expressions = [product.metadata_type.dataset_fields.get('product') == product.name] - with self._db.connect() as connection: + with self.db_connection() as connection: for record in connection.get_duplicates(group_fields, expressions): dataset_ids = set(record[0]) grouped_fields = tuple(record[1:]) @@ -276,7 +277,7 @@ def update(self, dataset: Dataset, updates_allowed=None): _LOG.info("Updating dataset %s", dataset.id) product = self.types.get_by_name(dataset.type.name) - with self._db.begin() as transaction: + with self.db_connection(transaction=True) as transaction: if not transaction.update_dataset(dataset.metadata_doc_without_lineage(), dataset.id, product.id): raise ValueError("Failed to update dataset %s..." 
% dataset.id) @@ -295,7 +296,7 @@ def insert_one(uri, transaction): # front of a stack for uri in new_uris[::-1]: if transaction is None: - with self._db.begin() as tr: + with self.db_connection(transaction=True) as tr: insert_one(uri, tr) else: insert_one(uri, transaction) @@ -306,7 +307,7 @@ def archive(self, ids): :param Iterable[UUID] ids: list of dataset ids to archive """ - with self._db.begin() as transaction: + with self.db_connection(transaction=True) as transaction: for id_ in ids: transaction.archive_dataset(id_) @@ -316,7 +317,7 @@ def restore(self, ids): :param Iterable[UUID] ids: list of dataset ids to restore """ - with self._db.begin() as transaction: + with self.db_connection(transaction=True) as transaction: for id_ in ids: transaction.restore_dataset(id_) @@ -326,7 +327,7 @@ def purge(self, ids: Iterable[DSID]): :param ids: iterable of dataset ids to purge """ - with self._db.begin() as transaction: + with self.db_connection(transaction=True) as transaction: for id_ in ids: transaction.delete_dataset(id_) @@ -340,7 +341,7 @@ def get_all_dataset_ids(self, archived: bool): :param archived: :rtype: list[UUID] """ - with self._db.begin() as transaction: + with self.db_connection(transaction=True) as transaction: return [dsid[0] for dsid in transaction.all_dataset_ids(archived)] def get_field_names(self, product_name=None): @@ -367,7 +368,7 @@ def get_locations(self, id_): :param typing.Union[UUID, str] id_: dataset id :rtype: list[str] """ - with self._db.connect() as connection: + with self.db_connection() as connection: return connection.get_locations(id_) def get_archived_locations(self, id_): @@ -377,7 +378,7 @@ def get_archived_locations(self, id_): :param typing.Union[UUID, str] id_: dataset id :rtype: list[str] """ - with self._db.connect() as connection: + with self.db_connection() as connection: return [uri for uri, archived_dt in connection.get_archived_locations(id_)] def get_archived_location_times(self, id_): @@ -387,7 +388,7 @@ def get_archived_location_times(self, id_): :param typing.Union[UUID, str] id_: dataset id :rtype: List[Tuple[str, datetime.datetime]] """ - with self._db.connect() as connection: + with self.db_connection() as connection: return list(connection.get_archived_locations(id_)) def add_location(self, id_, uri): @@ -402,7 +403,7 @@ def add_location(self, id_, uri): warnings.warn("Cannot add empty uri. (dataset %s)" % id_) return False - with self._db.connect() as connection: + with self.db_connection() as connection: return connection.insert_dataset_location(id_, uri) def get_datasets_for_location(self, uri, mode=None): @@ -413,7 +414,7 @@ def get_datasets_for_location(self, uri, mode=None): :param str mode: 'exact', 'prefix' or None (to guess) :return: """ - with self._db.connect() as connection: + with self.db_connection() as connection: return (self._make(row) for row in connection.get_datasets_for_location(uri, mode=mode)) def remove_location(self, id_, uri): @@ -424,7 +425,7 @@ def remove_location(self, id_, uri): :param str uri: fully qualified uri :returns bool: Was one removed? 
""" - with self._db.connect() as connection: + with self.db_connection() as connection: was_removed = connection.remove_location(id_, uri) return was_removed @@ -436,7 +437,7 @@ def archive_location(self, id_, uri): :param str uri: fully qualified uri :return bool: location was able to be archived """ - with self._db.connect() as connection: + with self.db_connection() as connection: was_archived = connection.archive_location(id_, uri) return was_archived @@ -448,7 +449,7 @@ def restore_location(self, id_, uri): :param str uri: fully qualified uri :return bool: location was able to be restored """ - with self._db.connect() as connection: + with self.db_connection() as connection: was_restored = connection.restore_location(id_, uri) return was_restored @@ -489,7 +490,7 @@ def search_by_metadata(self, metadata): :param dict metadata: :rtype: list[Dataset] """ - with self._db.connect() as connection: + with self.db_connection() as connection: for dataset in self._make_many(connection.search_datasets_by_metadata(metadata)): yield dataset @@ -645,7 +646,7 @@ def _do_search_by_product(self, query, return_fields=False, select_field_names=N else: select_fields = tuple(dataset_fields[field_name] for field_name in select_field_names) - with self._db.connect() as connection: + with self.db_connection() as connection: yield (product, connection.search_datasets( query_exprs, @@ -661,7 +662,7 @@ def _do_count_by_product(self, query): for q, product in product_queries: dataset_fields = product.metadata_type.dataset_fields query_exprs = tuple(fields.to_expressions(dataset_fields.get, **q)) - with self._db.connect() as connection: + with self.db_connection() as connection: count = connection.count_datasets(query_exprs) if count > 0: yield product, count @@ -686,7 +687,7 @@ def _do_time_count(self, period, query, ensure_single=False): for q, product in product_queries: dataset_fields = product.metadata_type.dataset_fields query_exprs = tuple(fields.to_expressions(dataset_fields.get, **q)) - with self._db.connect() as connection: + with self.db_connection() as connection: yield product, list(connection.count_datasets_through_time( start, end, @@ -731,7 +732,7 @@ def get_product_time_bounds(self, product: str): offset=max_offset, selection='greatest') - with self._db.connect() as connection: + with self.db_connection() as connection: result = connection.execute( select( [func.min(time_min.alchemy_expression), func.max(time_max.alchemy_expression)] @@ -785,7 +786,7 @@ class DatasetLight(result_type, DatasetSpatialMixin): class DatasetLight(result_type): # type: ignore __slots__ = () - with self._db.connect() as connection: + with self.db_connection() as connection: results = connection.search_unique_datasets( query_exprs, select_fields=select_fields, diff --git a/datacube/index/postgres/_metadata_types.py b/datacube/index/postgres/_metadata_types.py index ced04b0026..f0fdc680f9 100644 --- a/datacube/index/postgres/_metadata_types.py +++ b/datacube/index/postgres/_metadata_types.py @@ -7,6 +7,7 @@ from cachetools.func import lru_cache from datacube.index.abstract import AbstractMetadataTypeResource +from datacube.index.postgres._transaction import IndexResourceAddIn from datacube.model import MetadataType from datacube.utils import jsonify_document, changes, _readable_offset from datacube.utils.changes import check_doc_unchanged, get_doc_changes @@ -14,7 +15,7 @@ _LOG = logging.getLogger(__name__) -class MetadataTypeResource(AbstractMetadataTypeResource): +class 
MetadataTypeResource(AbstractMetadataTypeResource, IndexResourceAddIn): def __init__(self, db, index): """ :type db: datacube.drivers.postgres._connections.PostgresDb @@ -52,7 +53,8 @@ def add(self, metadata_type, allow_table_lock=False): Allow an exclusive lock to be taken on the table while creating the indexes. This will halt other user's requests until completed. - If false, creation will be slightly slower and cannot be done in a transaction. + If false (and a transaction is not already active), creation will be slightly slower + and cannot be done in a transaction. :rtype: datacube.model.MetadataType """ # This column duplication is getting out of hand: @@ -68,7 +70,7 @@ def add(self, metadata_type, allow_table_lock=False): 'Metadata Type {}'.format(metadata_type.name) ) else: - with self._db.connect() as connection: + with self.db_connection(transaction=allow_table_lock) as connection: connection.insert_metadata_type( name=metadata_type.name, definition=metadata_type.definition, @@ -142,7 +144,7 @@ def update(self, metadata_type: MetadataType, allow_unsafe_updates=False, allow_ _LOG.info("Updating metadata type %s", metadata_type.name) - with self._db.connect() as connection: + with self.db_connection(transaction=allow_table_lock) as connection: connection.update_metadata_type( name=metadata_type.name, definition=metadata_type.definition, @@ -168,7 +170,7 @@ def update_document(self, definition, allow_unsafe_updates=False): # This is memoized in the constructor # pylint: disable=method-hidden def get_unsafe(self, id_): # type: ignore - with self._db.connect() as connection: + with self.db_connection() as connection: record = connection.get_metadata_type(id_) if record is None: raise KeyError('%s is not a valid MetadataType id') @@ -177,7 +179,7 @@ def get_unsafe(self, id_): # type: ignore # This is memoized in the constructor # pylint: disable=method-hidden def get_by_name_unsafe(self, name): # type: ignore - with self._db.connect() as connection: + with self.db_connection() as connection: record = connection.get_metadata_type_by_name(name) if not record: raise KeyError('%s is not a valid MetadataType name' % name) @@ -193,7 +195,7 @@ def check_field_indexes(self, allow_table_lock=False, If false, creation will be slightly slower and cannot be done in a transaction. 
""" - with self._db.connect() as connection: + with self.db_connection(transaction=allow_table_lock) as connection: connection.check_dynamic_fields( concurrently=not allow_table_lock, rebuild_indexes=rebuild_indexes, @@ -206,7 +208,7 @@ def get_all(self): :rtype: iter[datacube.model.MetadataType] """ - with self._db.connect() as connection: + with self.db_connection() as connection: return self._make_many(connection.get_all_metadata_types()) def _make_many(self, query_rows): diff --git a/datacube/index/postgres/_products.py b/datacube/index/postgres/_products.py index 2af0745277..ca55df37fc 100644 --- a/datacube/index/postgres/_products.py +++ b/datacube/index/postgres/_products.py @@ -8,16 +8,18 @@ from datacube.index import fields from datacube.index.abstract import AbstractProductResource +from datacube.index.postgres._transaction import IndexResourceAddIn from datacube.model import DatasetType, MetadataType from datacube.utils import jsonify_document, changes, _readable_offset from datacube.utils.changes import check_doc_unchanged, get_doc_changes from typing import Iterable, cast + _LOG = logging.getLogger(__name__) -class ProductResource(AbstractProductResource): +class ProductResource(AbstractProductResource, IndexResourceAddIn): """ :type _db: datacube.drivers.postgres._connections.PostgresDb :type metadata_type_resource: datacube.index._metadata_types.MetadataTypeResource @@ -55,7 +57,8 @@ def add(self, product, allow_table_lock=False): Allow an exclusive lock to be taken on the table while creating the indexes. This will halt other user's requests until completed. - If false, creation will be slightly slower and cannot be done in a transaction. + If false (and there is no already active transaction), creation will be slightly slower + and cannot be done in a transaction. :param DatasetType product: Product to add :rtype: DatasetType """ @@ -75,7 +78,7 @@ def add(self, product, allow_table_lock=False): _LOG.warning('Adding metadata_type "%s" as it doesn\'t exist.', product.metadata_type.name) metadata_type = self.metadata_type_resource.add(product.metadata_type, allow_table_lock=allow_table_lock) - with self._db.connect() as connection: + with self.db_connection(transaction=allow_table_lock) as connection: connection.insert_product( name=product.name, metadata=product.metadata_doc, @@ -187,7 +190,7 @@ def update(self, product: DatasetType, allow_unsafe_updates=False, allow_table_l metadata_type = cast(MetadataType, self.metadata_type_resource.get_by_name(product.metadata_type.name)) # Given we cannot change metadata type because of the check above, and this is an # update method, the metadata type is guaranteed to already exist. 
- with self._db.connect() as conn: + with self.db_connection(transaction=allow_table_lock) as conn: conn.update_product( name=product.name, metadata=product.metadata_doc, @@ -225,7 +228,7 @@ def update_document(self, definition, allow_unsafe_updates=False, allow_table_lo # This is memoized in the constructor # pylint: disable=method-hidden def get_unsafe(self, id_): # type: ignore - with self._db.connect() as connection: + with self.db_connection() as connection: result = connection.get_product(id_) if not result: raise KeyError('"%s" is not a valid Product id' % id_) @@ -234,7 +237,7 @@ def get_unsafe(self, id_): # type: ignore # This is memoized in the constructor # pylint: disable=method-hidden def get_by_name_unsafe(self, name): # type: ignore - with self._db.connect() as connection: + with self.db_connection() as connection: result = connection.get_product_by_name(name) if not result: raise KeyError('"%s" is not a valid Product name' % name) @@ -306,7 +309,7 @@ def get_all(self) -> Iterable[DatasetType]: """ Retrieve all Products """ - with self._db.connect() as connection: + with self.db_connection() as connection: return (self._make(record) for record in connection.get_all_products()) def _make_many(self, query_rows): diff --git a/datacube/index/postgres/_users.py b/datacube/index/postgres/_users.py index 1d3c1362ce..073f1d5ced 100644 --- a/datacube/index/postgres/_users.py +++ b/datacube/index/postgres/_users.py @@ -1,13 +1,14 @@ # This file is part of the Open Data Cube, see https://opendatacube.org for more information # -# Copyright (c) 2015-2020 ODC Contributors +# Copyright (c) 2015-2022 ODC Contributors # SPDX-License-Identifier: Apache-2.0 from typing import Iterable, Optional, Tuple from datacube.index.abstract import AbstractUserResource +from datacube.index.postgres._transaction import IndexResourceAddIn from datacube.drivers.postgres import PostgresDb -class UserResource(AbstractUserResource): +class UserResource(AbstractUserResource, IndexResourceAddIn): def __init__(self, db: PostgresDb, index: "datacube.index.postgres.index.Index") -> None: """ :type db: datacube.drivers.postgres._connections.PostgresDb @@ -19,7 +20,7 @@ def grant_role(self, role: str, *usernames: str) -> None: """ Grant a role to users """ - with self._db.connect() as connection: + with self.db_connection() as connection: connection.grant_role(role, usernames) def create_user(self, username: str, password: str, @@ -27,14 +28,14 @@ def create_user(self, username: str, password: str, """ Create a new user. 
""" - with self._db.connect() as connection: + with self.db_connection() as connection: connection.create_user(username, password, role, description=description) def delete_user(self, *usernames: str) -> None: """ Delete a user """ - with self._db.connect() as connection: + with self.db_connection() as connection: connection.drop_users(usernames) def list_users(self) -> Iterable[Tuple[str, str, Optional[str]]]: @@ -42,6 +43,6 @@ def list_users(self) -> Iterable[Tuple[str, str, Optional[str]]]: :return: list of (role, user, description) :rtype: list[(str, str, str)] """ - with self._db.connect() as connection: + with self.db_connection() as connection: for role, user, description in connection.list_users(): yield role, user, description diff --git a/tests/index/test_api_index_dataset.py b/tests/index/test_api_index_dataset.py index 786b0aa500..a8ff2e8aae 100644 --- a/tests/index/test_api_index_dataset.py +++ b/tests/index/test_api_index_dataset.py @@ -214,6 +214,9 @@ def __init__(self, db, product): self._db = db self.products = MockTypesResource(product) + def thread_transaction(self): + return None + def test_index_dataset(): mock_db = MockDb() From 49c71e1008df30c7437242a1f1495364f2179e29 Mon Sep 17 00:00:00 2001 From: Paul Haesler Date: Wed, 14 Sep 2022 08:37:02 +1000 Subject: [PATCH 08/18] Oops - add missing file. --- datacube/index/postgres/_transaction.py | 68 +++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 datacube/index/postgres/_transaction.py diff --git a/datacube/index/postgres/_transaction.py b/datacube/index/postgres/_transaction.py new file mode 100644 index 0000000000..b2787525ee --- /dev/null +++ b/datacube/index/postgres/_transaction.py @@ -0,0 +1,68 @@ +# This file is part of the Open Data Cube, see https://opendatacube.org for more information +# +# Copyright (c) 2015-2022 ODC Contributors +# SPDX-License-Identifier: Apache-2.0 + + +from contextlib import contextmanager +from typing import Any + +from datacube.drivers.postgres import PostgresDb +from datacube.drivers.postgres._api import PostgresDbAPI +from datacube.index.abstract import AbstractTransaction + + +class PostgresTransaction(AbstractTransaction): + def __init__(self, db: PostgresDb, idx_id: str) -> None: + super().__init__(idx_id) + self._db = db + + def _new_connection(self) -> Any: + return self._db.begin() + + def _commit(self) -> None: + self._connection.commit() + + def _rollback(self) -> None: + self._connection.rollback() + + def _release_connection(self) -> None: + self._connection.close() + + +class IndexResourceAddIn: + @contextmanager + def db_connection(self, transaction=False) -> PostgresDbAPI: + """ + Context manager representing a database connection. + + If there is an active transaction for this index in the current thread, the connection object from that + transaction is returned, with the active transaction remaining in control of commit and rollback. + + If there is no active transaction and the transaction argument is True, a new transactionised connection + is returned, with this context manager handling commit and rollback. + + If there is no active transaction and the transaction argument is False (the default), a new connection + is returned with autocommit semantics. + + Note that autocommit behaviour is NOT available if there is an active transaction for the index + and the active thread. + + In Resource Manager code replace self._db.connect() with self.db_connection(), and replace + self._db.begin() with self.db_connection(transaction=True). 
+ + :param transaction: Use a transaction if one is not already active for the thread. + :return: A PostgresDbAPI object, with the specified transaction semantics. + """ + trans = self._index.thread_transaction() + if trans is not None: + # Use active transaction + yield trans._connection + elif transaction: + with self._db.begin() as conn: + yield conn + else: + # Autocommit behaviour: + with self._db.connect() as conn: + yield conn + From 742504fcf97fbdcd1a728444e605f3f3fd03010c Mon Sep 17 00:00:00 2001 From: Paul Haesler Date: Tue, 20 Sep 2022 14:24:41 +1000 Subject: [PATCH 09/18] Tests and postgres driver implementation. --- datacube/index/abstract.py | 24 +++--- datacube/index/postgres/_transaction.py | 17 +++- datacube/index/postgres/index.py | 20 +---- integration_tests/conftest.py | 2 +- integration_tests/index/test_index_data.py | 99 +++++++++++++++++++++- 5 files changed, 125 insertions(+), 37 deletions(-) diff --git a/datacube/index/abstract.py b/datacube/index/abstract.py index 3e9442d80d..80bee46a1f 100644 --- a/datacube/index/abstract.py +++ b/datacube/index/abstract.py @@ -1000,8 +1000,8 @@ class AbstractTransaction(ABC): def __init__(self, index_id: str): self._connection: Any = None - self.tls_id = f"txn-{index_id}" - self.obj_lock = Lock() + self._tls_id = f"txn-{index_id}" + self._obj_lock = Lock() # Main Transaction API def begin(self) -> None: @@ -1012,7 +1012,7 @@ def begin(self) -> None: Calls implementation-specific _new_connection() method and manages thread local storage and locks. """ - with self.obj_lock: + with self._obj_lock: if self._connection is not None: raise ValueError("Cannot start a new transaction as one is already active") self._tls_stash() @@ -1025,7 +1025,7 @@ def commit(self) -> None: Calls implementation-specific _commit() method, and manages thread local storage and locks. """ - with self.obj_lock: + with self._obj_lock: if self._connection is None: raise ValueError("Cannot commit inactive transaction") self._commit() @@ -1041,7 +1041,7 @@ def rollback(self) -> None: Calls implementation-specific _rollback() method, and manages thread local storage and locks. """ - with self.obj_lock: + with self._obj_lock: if self._connection is None: raise ValueError("Cannot rollback inactive transaction") self._rollback() @@ -1062,15 +1062,15 @@ def _tls_stash(self) -> None: Check TLS is empty, create a new connection and stash it. :return: """ - stored_val = thread_local_cache(self.tls_id) + stored_val = thread_local_cache(self._tls_id) if stored_val is not None: raise ValueError("Cannot start a new transaction as one is already active for this thread") self._connection = self._new_connection() - thread_local_cache(self.tls_id, purge=True) - thread_local_cache(self.tls_id, self) + thread_local_cache(self._tls_id, purge=True) + thread_local_cache(self._tls_id, self) def _tls_purge(self) -> None: - thread_local_cache(self.tls_id, purge=True) + thread_local_cache(self._tls_id, purge=True) # Commit/Rollback exceptions for Context Manager usage patterns def commit_exception(self, errmsg: str) -> TransactionException: @@ -1082,9 +1082,13 @@ def rollback_exception(self, errmsg: str) -> TransactionException: # Context Manager Interface def __enter__(self): self.begin() + return self def __exit__(self, exc_type, exc_value, traceback): - if issubclass(exc_type, TransactionException): + if not self.active: + # User has already manually committed or rolled back. 
+ return True + if exc_type is not None and issubclass(exc_type, TransactionException): # User raised a TransactionException, Commit or rollback as per exception if exc_value.commit: self.commit() diff --git a/datacube/index/postgres/_transaction.py b/datacube/index/postgres/_transaction.py index b2787525ee..a8d6bcfc10 100644 --- a/datacube/index/postgres/_transaction.py +++ b/datacube/index/postgres/_transaction.py @@ -3,14 +3,17 @@ # Copyright (c) 2015-2022 ODC Contributors # SPDX-License-Identifier: Apache-2.0 +import logging from contextlib import contextmanager +from sqlalchemy import text from typing import Any from datacube.drivers.postgres import PostgresDb from datacube.drivers.postgres._api import PostgresDbAPI from datacube.index.abstract import AbstractTransaction +_LOG = logging.getLogger(__name__) class PostgresTransaction(AbstractTransaction): def __init__(self, db: PostgresDb, idx_id: str) -> None: @@ -18,7 +21,10 @@ def __init__(self, db: PostgresDb, idx_id: str) -> None: self._db = db def _new_connection(self) -> Any: - return self._db.begin() + dbconn = self._db.give_me_a_connection() + dbconn.execute(text('BEGIN')) + conn = PostgresDbAPI(dbconn) + return conn def _commit(self) -> None: self._connection.commit() @@ -27,12 +33,13 @@ def _rollback(self) -> None: self._connection.rollback() def _release_connection(self) -> None: - self._connection.close() + self._connection._connection.close() + self._connection._connection = None class IndexResourceAddIn: @contextmanager - def db_connection(self, transaction=False) -> PostgresDbAPI: + def db_connection(self, transaction: bool = False) -> PostgresDbAPI: """ Context manager representing a database connection. @@ -55,14 +62,16 @@ def db_connection(self, transaction=False) -> PostgresDbAPI: :return: A PostgresDbAPI object, with the specified transaction semantics. """ trans = self._index.thread_transaction() + closing = False if trans is not None: # Use active transaction yield trans._connection elif transaction: + closing = True with self._db.begin() as conn: yield conn else: + closing = True # Autocommit behaviour: with self._db.connect() as conn: yield conn - diff --git a/datacube/index/postgres/index.py b/datacube/index/postgres/index.py index d5f67715d0..48be6456c5 100644 --- a/datacube/index/postgres/index.py +++ b/datacube/index/postgres/index.py @@ -3,9 +3,11 @@ # Copyright (c) 2015-2020 ODC Contributors # SPDX-License-Identifier: Apache-2.0 import logging +from contextlib import contextmanager from typing import Any from datacube.drivers.postgres import PostgresDb +from datacube.index.postgres._transaction import PostgresTransaction from datacube.index.postgres._datasets import DatasetResource # type: ignore from datacube.index.postgres._metadata_types import MetadataTypeResource from datacube.index.postgres._products import ProductResource @@ -17,24 +19,6 @@ _LOG = logging.getLogger(__name__) -class PostgresTransaction(AbstractTransaction): - def __init__(self, db: PostgresDb, idx_id: str) -> None: - super().__init__(idx_id) - self._db = db - - def _new_connection(self) -> Any: - return self._db.begin() - - def _commit(self) -> None: - self._connection.commit() - - def _rollback(self) -> None: - self._connection.rollback() - - def _release_connection(self) -> None: - self._connection.close() - - class Index(AbstractIndex): """ Access to the datacube index. 
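
To make the migration described in the db_connection() docstring concrete, the sketch below shows what a resource class looks like once it uses the mix-in. ExampleResource and its methods are hypothetical and only illustrate the two call patterns named in the docstring; they are not part of the patch.

    from datacube.index.postgres._transaction import IndexResourceAddIn

    class ExampleResource(IndexResourceAddIn):
        """Hypothetical resource showing the intended db_connection() call patterns."""

        def __init__(self, db, index):
            self._db = db          # PostgresDb: the mix-in opens connections through it
            self._index = index    # owning Index: consulted for an active thread transaction

        def get(self, id_):
            # Read path - previously `with self._db.connect() as connection:` (autocommit)
            with self.db_connection() as connection:
                return connection.get_dataset(id_)

        def archive(self, ids):
            # Write path - previously `with self._db.begin() as transaction:` (transactional)
            with self.db_connection(transaction=True) as transaction:
                for id_ in ids:
                    transaction.archive_dataset(id_)

The later patches in this series apply exactly this substitution across the dataset, product, metadata-type and user resources, and rename the helper to _db_connection.
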
diff --git a/integration_tests/conftest.py b/integration_tests/conftest.py index cbd45619b3..e59e76f80a 100644 --- a/integration_tests/conftest.py +++ b/integration_tests/conftest.py @@ -379,7 +379,7 @@ def index_empty(local_config, uninitialised_postgres_db: PostgresDb): @pytest.fixture def initialised_postgres_db(index): """ - Return a connection to an PostgreSQL database, initialised with the default schema + Return a connection to an PostgreSQL (or Postgis) database, initialised with the default schema and tables. """ return index._db diff --git a/integration_tests/index/test_index_data.py b/integration_tests/index/test_index_data.py index a41bc01fc0..dc341b9c56 100755 --- a/integration_tests/index/test_index_data.py +++ b/integration_tests/index/test_index_data.py @@ -273,10 +273,10 @@ def test_get_dataset(index: Index, telemetry_dataset: Dataset) -> None: 'f226a278-e422-11e6-b501-185e0f80a5c1']) == [] -def test_transactions(index: Index, - initialised_postgres_db: PostgresDb, - local_config, - default_metadata_type) -> None: +def test_transactions_internal_api(index: Index, + initialised_postgres_db: PostgresDb, + local_config, + default_metadata_type) -> None: assert not index.datasets.has(_telemetry_uuid) dataset_type = index.products.add_document(_pseudo_telemetry_dataset_type) @@ -297,6 +297,97 @@ def test_transactions(index: Index, assert not index.datasets.has(_telemetry_uuid) +def test_transactions_api_ctx_mgr(index, + extended_eo3_metadata_type_doc, + ls8_eo3_product, + eo3_ls8_dataset_doc, + eo3_ls8_dataset2_doc): + from datacube.index.hl import Doc2Dataset + import logging + _LOG = logging.getLogger(__name__) + resolver = Doc2Dataset(index, products=[ls8_eo3_product.name], verify_lineage=False) + ds1, err = resolver(*eo3_ls8_dataset_doc) + ds2, err = resolver(*eo3_ls8_dataset2_doc) + with pytest.raises(Exception) as e: + with index.transaction() as trans: + assert index.datasets.get(ds1.id) is None + index.datasets.add(ds1) + assert index.datasets.get(ds1.id) is not None + raise Exception("Rollback!") + assert "Rollback!" 
in str(e.value) + assert index.datasets.get(ds1.id) is None + with index.transaction() as trans: + assert index.datasets.get(ds1.id) is None + index.datasets.add(ds1) + assert index.datasets.get(ds1.id) is not None + assert index.datasets.get(ds1.id) is not None + with index.transaction() as trans: + index.datasets.add(ds2) + assert index.datasets.get(ds2.id) is not None + raise trans.rollback_exception("Rollback") + assert index.datasets.get(ds1.id) is not None + assert index.datasets.get(ds2.id) is None + + +def test_transactions_api_manual(index, + extended_eo3_metadata_type_doc, + ls8_eo3_product, + eo3_ls8_dataset_doc, + eo3_ls8_dataset2_doc): + from datacube.index.hl import Doc2Dataset + import logging + _LOG = logging.getLogger(__name__) + resolver = Doc2Dataset(index, products=[ls8_eo3_product.name], verify_lineage=False) + ds1, err = resolver(*eo3_ls8_dataset_doc) + ds2, err = resolver(*eo3_ls8_dataset2_doc) + trans = index.transaction() + index.datasets.add(ds1) + assert index.datasets.get(ds1.id) is not None + trans.begin() + index.datasets.add(ds2) + assert index.datasets.get(ds1.id) is not None + assert index.datasets.get(ds2.id) is not None + trans.rollback() + assert index.datasets.get(ds1.id) is not None + assert index.datasets.get(ds2.id) is None + trans.begin() + index.datasets.add(ds2) + trans.commit() + assert index.datasets.get(ds1.id) is not None + assert index.datasets.get(ds2.id) is not None + + +def test_transactions_api_hybrid(index, + extended_eo3_metadata_type_doc, + ls8_eo3_product, + eo3_ls8_dataset_doc, + eo3_ls8_dataset2_doc): + from datacube.index.hl import Doc2Dataset + import logging + _LOG = logging.getLogger(__name__) + resolver = Doc2Dataset(index, products=[ls8_eo3_product.name], verify_lineage=False) + ds1, err = resolver(*eo3_ls8_dataset_doc) + ds2, err = resolver(*eo3_ls8_dataset2_doc) + with index.transaction() as trans: + assert index.datasets.get(ds1.id) is None + index.datasets.add(ds1) + assert index.datasets.get(ds1.id) is not None + trans.rollback() + assert index.datasets.get(ds1.id) is None + trans.begin() + assert index.datasets.get(ds1.id) is None + index.datasets.add(ds1) + assert index.datasets.get(ds1.id) is not None + trans.commit() + assert index.datasets.get(ds1.id) is not None + trans.begin() + index.datasets.add(ds2) + assert index.datasets.get(ds2.id) is not None + trans.rollback() + assert index.datasets.get(ds1.id) is not None + assert index.datasets.get(ds2.id) is None + + def test_get_missing_things(index: Index) -> None: """ The get(id) methods should return None if the object doesn't exist. 
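
Taken together, the three tests above exercise every public entry point of the new transaction API. Condensed into one place, and reusing the index, ds1 and ds2 objects from the surrounding fixtures, the patterns they cover are sketched below (the comments summarise the tested behaviour rather than quote the patch).

    # 1. Context-manager form: commit on clean exit; any exception triggers a rollback.
    #    A TransactionException raised via trans.rollback_exception() is consumed,
    #    while an ordinary exception is re-raised after the rollback.
    with index.transaction() as trans:
        index.datasets.add(ds1)
        raise trans.rollback_exception("changed my mind")   # rolled back, not propagated

    # 2. Manual form: begin/commit/rollback called explicitly on the transaction object.
    trans = index.transaction()
    trans.begin()
    index.datasets.add(ds2)
    trans.rollback()         # the add of ds2 is undone
    trans.begin()
    index.datasets.add(ds2)
    trans.commit()           # ds2 is now permanently indexed

    # 3. Hybrid form: commit or rollback manually inside the with-block; on exit the
    #    context manager sees the transaction is no longer active and does nothing more.
    with index.transaction() as trans:
        index.datasets.add(ds1)
        trans.commit()
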
From 522db9148716ce2275046319d1931bc941e13b65 Mon Sep 17 00:00:00 2001 From: Paul Haesler Date: Tue, 20 Sep 2022 14:28:52 +1000 Subject: [PATCH 10/18] lint format --- integration_tests/index/test_index_data.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/integration_tests/index/test_index_data.py b/integration_tests/index/test_index_data.py index dc341b9c56..d448d8f312 100755 --- a/integration_tests/index/test_index_data.py +++ b/integration_tests/index/test_index_data.py @@ -330,10 +330,10 @@ def test_transactions_api_ctx_mgr(index, def test_transactions_api_manual(index, - extended_eo3_metadata_type_doc, - ls8_eo3_product, - eo3_ls8_dataset_doc, - eo3_ls8_dataset2_doc): + extended_eo3_metadata_type_doc, + ls8_eo3_product, + eo3_ls8_dataset_doc, + eo3_ls8_dataset2_doc): from datacube.index.hl import Doc2Dataset import logging _LOG = logging.getLogger(__name__) From 10c1b5c21b5bbd64e4206f4099d44c9b03d75b56 Mon Sep 17 00:00:00 2001 From: Paul Haesler Date: Wed, 21 Sep 2022 10:42:31 +1000 Subject: [PATCH 11/18] Postgis implementation (plus bug fixes in postgres). --- datacube/index/postgis/_datasets.py | 57 ++++++++-------- datacube/index/postgis/_metadata_types.py | 15 +++-- datacube/index/postgis/_products.py | 13 ++-- datacube/index/postgis/_transaction.py | 77 ++++++++++++++++++++++ datacube/index/postgis/_users.py | 11 ++-- datacube/index/postgis/index.py | 19 +----- datacube/index/postgres/_datasets.py | 54 +++++++-------- datacube/index/postgres/_metadata_types.py | 12 ++-- datacube/index/postgres/_products.py | 10 +-- datacube/index/postgres/_transaction.py | 2 +- datacube/index/postgres/_users.py | 8 +-- integration_tests/index/test_index_data.py | 8 +-- 12 files changed, 175 insertions(+), 111 deletions(-) create mode 100644 datacube/index/postgis/_transaction.py diff --git a/datacube/index/postgis/_datasets.py b/datacube/index/postgis/_datasets.py index 5903c5f242..ec2b50f69d 100755 --- a/datacube/index/postgis/_datasets.py +++ b/datacube/index/postgis/_datasets.py @@ -17,6 +17,7 @@ from datacube.drivers.postgis._fields import SimpleDocField, DateDocField from datacube.drivers.postgis._schema import Dataset as SQLDataset from datacube.index.abstract import AbstractDatasetResource, DatasetSpatialMixin, DSID +from datacube.index.postgis._transaction import IndexResourceAddIn from datacube.model import Dataset, Product from datacube.model.fields import Field from datacube.model.utils import flatten_datasets @@ -32,7 +33,7 @@ # pylint: disable=too-many-public-methods, too-many-lines -class DatasetResource(AbstractDatasetResource): +class DatasetResource(AbstractDatasetResource, IndexResourceAddIn): """ :type _db: datacube.drivers.postgis._connections.PostgresDb :type types: datacube.index._products.ProductResource @@ -59,7 +60,7 @@ def get(self, id_: Union[str, UUID], include_sources=False): if isinstance(id_, str): id_ = UUID(id_) - with self._db.connect() as connection: + with self._db_connection() as connection: if not include_sources: dataset = connection.get_dataset(id_) return self._make(dataset, full_info=True) if dataset else None @@ -88,7 +89,7 @@ def to_uuid(x): ids = [to_uuid(i) for i in ids] - with self._db.connect() as connection: + with self._db_connection() as connection: rows = connection.get_datasets(ids) return [self._make(r, full_info=True) for r in rows] @@ -101,7 +102,7 @@ def get_derived(self, id_): """ if not isinstance(id_, UUID): id_ = UUID(id_) - with self._db.connect() as connection: + with self._db_connection() as connection: 
return [ self._make(result, full_info=True) for result in connection.get_derived_datasets(id_) @@ -114,7 +115,7 @@ def has(self, id_): :param typing.Union[UUID, str] id_: dataset id :rtype: bool """ - with self._db.connect() as connection: + with self._db_connection() as connection: return connection.contains_dataset(id_) def bulk_has(self, ids_): @@ -127,7 +128,7 @@ def bulk_has(self, ids_): :rtype: [bool] """ - with self._db.connect() as connection: + with self._db_connection() as connection: existing = set(connection.datasets_intersection(ids_)) return [x in existing for x in @@ -194,7 +195,7 @@ def process_bunch(dss, main_ds, transaction): dss = [dataset] - with self._db.begin() as transaction: + with self._db_connection(transaction=True) as transaction: process_bunch(dss, dataset, transaction) return dataset @@ -219,7 +220,7 @@ def load_field(f: Union[str, fields.Field]) -> fields.Field: expressions = [product.metadata_type.dataset_fields.get('product') == product.name] - with self._db.connect() as connection: + with self._db_connection() as connection: for record in connection.get_duplicates(group_fields, expressions): dataset_ids = set(record[0]) grouped_fields = tuple(record[1:]) @@ -289,7 +290,7 @@ def update(self, dataset: Dataset, updates_allowed=None): _LOG.info("Updating dataset %s", dataset.id) product = self.types.get_by_name(dataset.type.name) - with self._db.begin() as transaction: + with self._db_connection(transaction=True) as transaction: if not transaction.update_dataset(dataset.metadata_doc_without_lineage(), dataset.id, product.id): raise ValueError("Failed to update dataset %s..." % dataset.id) @@ -308,7 +309,7 @@ def insert_one(uri, transaction): # front of a stack for uri in new_uris[::-1]: if transaction is None: - with self._db.begin() as tr: + with self._db_connection(transaction=True) as tr: insert_one(uri, tr) else: insert_one(uri, transaction) @@ -319,7 +320,7 @@ def archive(self, ids): :param Iterable[UUID] ids: list of dataset ids to archive """ - with self._db.begin() as transaction: + with self._db_connection(transaction=True) as transaction: for id_ in ids: transaction.archive_dataset(id_) @@ -329,7 +330,7 @@ def restore(self, ids): :param Iterable[UUID] ids: list of dataset ids to restore """ - with self._db.begin() as transaction: + with self._db_connection(transaction=True) as transaction: for id_ in ids: transaction.restore_dataset(id_) @@ -339,7 +340,7 @@ def purge(self, ids: Iterable[DSID]): :param ids: iterable of dataset ids to purge """ - with self._db.begin() as transaction: + with self._db_connection(transaction=True) as transaction: for id_ in ids: transaction.delete_dataset(id_) @@ -353,7 +354,7 @@ def get_all_dataset_ids(self, archived: bool): :param archived: :rtype: list[UUID] """ - with self._db.begin() as transaction: + with self._db_connection(transaction=True) as transaction: return [dsid[0] for dsid in transaction.all_dataset_ids(archived)] def get_field_names(self, product_name=None): @@ -380,7 +381,7 @@ def get_locations(self, id_): :param typing.Union[UUID, str] id_: dataset id :rtype: list[str] """ - with self._db.connect() as connection: + with self._db_connection() as connection: return connection.get_locations(id_) def get_archived_locations(self, id_): @@ -390,7 +391,7 @@ def get_archived_locations(self, id_): :param typing.Union[UUID, str] id_: dataset id :rtype: list[str] """ - with self._db.connect() as connection: + with self._db_connection() as connection: return [uri for uri, archived_dt in 
connection.get_archived_locations(id_)] def get_archived_location_times(self, id_): @@ -400,7 +401,7 @@ def get_archived_location_times(self, id_): :param typing.Union[UUID, str] id_: dataset id :rtype: List[Tuple[str, datetime.datetime]] """ - with self._db.connect() as connection: + with self._db_connection() as connection: return list(connection.get_archived_locations(id_)) def add_location(self, id_, uri): @@ -415,7 +416,7 @@ def add_location(self, id_, uri): warnings.warn("Cannot add empty uri. (dataset %s)" % id_) return False - with self._db.connect() as connection: + with self._db_connection() as connection: return connection.insert_dataset_location(id_, uri) def get_datasets_for_location(self, uri, mode=None): @@ -426,7 +427,7 @@ def get_datasets_for_location(self, uri, mode=None): :param str mode: 'exact', 'prefix' or None (to guess) :return: """ - with self._db.connect() as connection: + with self._db_connection() as connection: return (self._make(row) for row in connection.get_datasets_for_location(uri, mode=mode)) def remove_location(self, id_, uri): @@ -437,7 +438,7 @@ def remove_location(self, id_, uri): :param str uri: fully qualified uri :returns bool: Was one removed? """ - with self._db.connect() as connection: + with self._db_connection() as connection: was_removed = connection.remove_location(id_, uri) return was_removed @@ -449,7 +450,7 @@ def archive_location(self, id_, uri): :param str uri: fully qualified uri :return bool: location was able to be archived """ - with self._db.connect() as connection: + with self._db_connection() as connection: was_archived = connection.archive_location(id_, uri) return was_archived @@ -461,7 +462,7 @@ def restore_location(self, id_, uri): :param str uri: fully qualified uri :return bool: location was able to be restored """ - with self._db.connect() as connection: + with self._db_connection() as connection: was_restored = connection.restore_location(id_, uri) return was_restored @@ -502,7 +503,7 @@ def search_by_metadata(self, metadata): :param dict metadata: :rtype: list[Dataset] """ - with self._db.connect() as connection: + with self._db_connection() as connection: for dataset in self._make_many(connection.search_datasets_by_metadata(metadata)): yield dataset @@ -651,7 +652,7 @@ def _do_search_by_product(self, query, return_fields=False, select_field_names=N else: select_fields = tuple(dataset_fields[field_name] for field_name in select_field_names) - with self._db.connect() as connection: + with self._db_connection() as connection: yield (product, connection.search_datasets( query_exprs, @@ -667,7 +668,7 @@ def _do_count_by_product(self, query): for q, product in product_queries: dataset_fields = product.metadata_type.dataset_fields query_exprs = tuple(fields.to_expressions(dataset_fields.get, **q)) - with self._db.connect() as connection: + with self._db_connection() as connection: count = connection.count_datasets(query_exprs) if count > 0: yield product, count @@ -692,7 +693,7 @@ def _do_time_count(self, period, query, ensure_single=False): for q, product in product_queries: dataset_fields = product.metadata_type.dataset_fields query_exprs = tuple(fields.to_expressions(dataset_fields.get, **q)) - with self._db.connect() as connection: + with self._db_connection() as connection: yield product, list(connection.count_datasets_through_time( start, end, @@ -739,7 +740,7 @@ def get_product_time_bounds(self, product: str): offset=max_offset, selection='greatest') - with self._db.connect() as connection: + with self._db_connection() 
as connection: result = connection.execute( select( [func.min(time_min.alchemy_expression), func.max(time_max.alchemy_expression)] @@ -793,7 +794,7 @@ class DatasetLight(result_type, DatasetSpatialMixin): class DatasetLight(result_type): # type: ignore __slots__ = () - with self._db.connect() as connection: + with self._db_connection() as connection: results = connection.search_unique_datasets( query_exprs, select_fields=select_fields, diff --git a/datacube/index/postgis/_metadata_types.py b/datacube/index/postgis/_metadata_types.py index 809f95c3b8..e01afdde87 100644 --- a/datacube/index/postgis/_metadata_types.py +++ b/datacube/index/postgis/_metadata_types.py @@ -7,6 +7,7 @@ from cachetools.func import lru_cache from datacube.index.abstract import AbstractMetadataTypeResource +from datacube.index.postgis._transaction import IndexResourceAddIn from datacube.model import MetadataType from datacube.utils import jsonify_document, changes, _readable_offset from datacube.utils.changes import check_doc_unchanged, get_doc_changes @@ -14,7 +15,7 @@ _LOG = logging.getLogger(__name__) -class MetadataTypeResource(AbstractMetadataTypeResource): +class MetadataTypeResource(AbstractMetadataTypeResource, IndexResourceAddIn): def __init__(self, db, index): """ :type db: datacube.drivers.postgis._connections.PostgresDb @@ -68,7 +69,7 @@ def add(self, metadata_type, allow_table_lock=False): 'Metadata Type {}'.format(metadata_type.name) ) else: - with self._db.connect() as connection: + with self._db_connection() as connection: connection.insert_metadata_type( name=metadata_type.name, definition=metadata_type.definition, @@ -142,7 +143,7 @@ def update(self, metadata_type: MetadataType, allow_unsafe_updates=False, allow_ _LOG.info("Updating metadata type %s", metadata_type.name) - with self._db.connect() as connection: + with self._db_connection() as connection: connection.update_metadata_type( name=metadata_type.name, definition=metadata_type.definition, @@ -168,7 +169,7 @@ def update_document(self, definition, allow_unsafe_updates=False): # This is memoized in the constructor # pylint: disable=method-hidden def get_unsafe(self, id_): # type: ignore - with self._db.connect() as connection: + with self._db_connection() as connection: record = connection.get_metadata_type(id_) if record is None: raise KeyError('%s is not a valid MetadataType id') @@ -177,7 +178,7 @@ def get_unsafe(self, id_): # type: ignore # This is memoized in the constructor # pylint: disable=method-hidden def get_by_name_unsafe(self, name): # type: ignore - with self._db.connect() as connection: + with self._db_connection() as connection: record = connection.get_metadata_type_by_name(name) if not record: raise KeyError('%s is not a valid MetadataType name' % name) @@ -193,7 +194,7 @@ def check_field_indexes(self, allow_table_lock=False, If false, creation will be slightly slower and cannot be done in a transaction. 
""" - with self._db.connect() as connection: + with self._db_connection() as connection: connection.check_dynamic_fields( concurrently=not allow_table_lock, rebuild_indexes=rebuild_indexes, @@ -206,7 +207,7 @@ def get_all(self): :rtype: iter[datacube.model.MetadataType] """ - with self._db.connect() as connection: + with self._db_connection() as connection: return self._make_many(connection.get_all_metadata_types()) def _make_many(self, query_rows): diff --git a/datacube/index/postgis/_products.py b/datacube/index/postgis/_products.py index 3ad45b5f3c..e94d306e84 100644 --- a/datacube/index/postgis/_products.py +++ b/datacube/index/postgis/_products.py @@ -8,6 +8,7 @@ from datacube.index import fields from datacube.index.abstract import AbstractProductResource +from datacube.index.postgis._transaction import IndexResourceAddIn from datacube.model import Product, MetadataType from datacube.utils import jsonify_document, changes, _readable_offset from datacube.utils.changes import check_doc_unchanged, get_doc_changes @@ -17,7 +18,7 @@ _LOG = logging.getLogger(__name__) -class ProductResource(AbstractProductResource): +class ProductResource(AbstractProductResource, IndexResourceAddIn): """ :type _db: datacube.drivers.postgis._connections.PostgresDb :type metadata_type_resource: datacube.index._metadata_types.MetadataTypeResource @@ -75,7 +76,7 @@ def add(self, product, allow_table_lock=False): _LOG.warning('Adding metadata_type "%s" as it doesn\'t exist.', product.metadata_type.name) metadata_type = self.metadata_type_resource.add(product.metadata_type, allow_table_lock=allow_table_lock) - with self._db.connect() as connection: + with self._db_connection() as connection: connection.insert_product( name=product.name, metadata=product.metadata_doc, @@ -184,7 +185,7 @@ def update(self, product: Product, allow_unsafe_updates=False, allow_table_lock= metadata_type = self.metadata_type_resource.get_by_name(product.metadata_type.name) # TODO: should we add metadata type here? assert metadata_type, "TODO: should we add metadata type here?" 
- with self._db.connect() as conn: + with self._db_connection() as conn: conn.update_product( name=product.name, metadata=product.metadata_doc, @@ -222,7 +223,7 @@ def update_document(self, definition, allow_unsafe_updates=False, allow_table_lo # This is memoized in the constructor # pylint: disable=method-hidden def get_unsafe(self, id_): # type: ignore - with self._db.connect() as connection: + with self._db_connection() as connection: result = connection.get_product(id_) if not result: raise KeyError('"%s" is not a valid Product id' % id_) @@ -231,7 +232,7 @@ def get_unsafe(self, id_): # type: ignore # This is memoized in the constructor # pylint: disable=method-hidden def get_by_name_unsafe(self, name): # type: ignore - with self._db.connect() as connection: + with self._db_connection() as connection: result = connection.get_product_by_name(name) if not result: raise KeyError('"%s" is not a valid Product name' % name) @@ -306,7 +307,7 @@ def get_all(self) -> Iterable[Product]: """ Retrieve all Products """ - with self._db.connect() as connection: + with self._db_connection() as connection: return (self._make(record) for record in connection.get_all_products()) def _make_many(self, query_rows): diff --git a/datacube/index/postgis/_transaction.py b/datacube/index/postgis/_transaction.py new file mode 100644 index 0000000000..130fc6fe8c --- /dev/null +++ b/datacube/index/postgis/_transaction.py @@ -0,0 +1,77 @@ +# This file is part of the Open Data Cube, see https://opendatacube.org for more information +# +# Copyright (c) 2015-2022 ODC Contributors +# SPDX-License-Identifier: Apache-2.0 + +import logging + +from contextlib import contextmanager +from sqlalchemy import text +from typing import Any + +from datacube.drivers.postgis import PostGisDb +from datacube.drivers.postgis._api import PostgisDbAPI +from datacube.index.abstract import AbstractTransaction + +_LOG = logging.getLogger(__name__) + +class PostgisTransaction(AbstractTransaction): + def __init__(self, db: PostGisDb, idx_id: str) -> None: + super().__init__(idx_id) + self._db = db + + def _new_connection(self) -> Any: + dbconn = self._db.give_me_a_connection() + dbconn.execute(text('BEGIN')) + conn = PostgisDbAPI(self._db, dbconn) + return conn + + def _commit(self) -> None: + self._connection.commit() + + def _rollback(self) -> None: + self._connection.rollback() + + def _release_connection(self) -> None: + self._connection._connection.close() + self._connection._connection = None + + +class IndexResourceAddIn: + @contextmanager + def _db_connection(self, transaction: bool = False) -> PostgisDbAPI: + """ + Context manager representing a database connection. + + If there is an active transaction for this index in the current thread, the connection object from that + transaction is returned, with the active transaction remaining in control of commit and rollback. + + If there is no active transaction and the transaction argument is True, a new transactionised connection + is returned, with this context manager handling commit and rollback. + + If there is no active transaction and the transaction argument is False (the default), a new connection + is returned with autocommit semantics. + + Note that autocommit behaviour is NOT available if there is an active transaction for the index + and the active thread. + + In Resource Manager code replace self._db.connect() with self.db_connection(), and replace + self._db.begin() with self.db_connection(transaction=True). 
+ + :param transaction: Use a transaction if one is not already active for the thread. + :return: A PostgresDbAPI object, with the specified transaction semantics. + """ + trans = self._index.thread_transaction() + closing = False + if trans is not None: + # Use active transaction + yield trans._connection + elif transaction: + closing = True + with self._db.begin() as conn: + yield conn + else: + closing = True + # Autocommit behaviour: + with self._db.connect() as conn: + yield conn diff --git a/datacube/index/postgis/_users.py b/datacube/index/postgis/_users.py index ee7e7b4a3d..1acb9a7f77 100644 --- a/datacube/index/postgis/_users.py +++ b/datacube/index/postgis/_users.py @@ -4,10 +4,11 @@ # SPDX-License-Identifier: Apache-2.0 from typing import Iterable, Optional, Tuple from datacube.index.abstract import AbstractUserResource +from datacube.index.postgis._transaction import IndexResourceAddIn from datacube.drivers.postgis import PostGisDb -class UserResource(AbstractUserResource): +class UserResource(AbstractUserResource, IndexResourceAddIn): def __init__(self, db: PostGisDb, index: "datacube.index.postgis.index.Index") -> None: """ :type db: datacube.drivers.postgis.PostGisDb @@ -19,7 +20,7 @@ def grant_role(self, role: str, *usernames: str) -> None: """ Grant a role to users """ - with self._db.connect() as connection: + with self._db_connection() as connection: connection.grant_role(role, usernames) def create_user(self, username: str, password: str, @@ -27,14 +28,14 @@ def create_user(self, username: str, password: str, """ Create a new user. """ - with self._db.connect() as connection: + with self._db_connection() as connection: connection.create_user(username, password, role, description=description) def delete_user(self, *usernames: str) -> None: """ Delete a user """ - with self._db.connect() as connection: + with self._db_connection() as connection: connection.drop_users(usernames) def list_users(self) -> Iterable[Tuple[str, str, Optional[str]]]: @@ -42,6 +43,6 @@ def list_users(self) -> Iterable[Tuple[str, str, Optional[str]]]: :return: list of (role, user, description) :rtype: list[(str, str, str)] """ - with self._db.connect() as connection: + with self._db_connection() as connection: for role, user, description in connection.list_users(): yield role, user, description diff --git a/datacube/index/postgis/index.py b/datacube/index/postgis/index.py index 95fd7d5053..3c4d82aea4 100644 --- a/datacube/index/postgis/index.py +++ b/datacube/index/postgis/index.py @@ -6,6 +6,7 @@ from typing import Any, Iterable, Sequence from datacube.drivers.postgis import PostGisDb +from datacube.index.postgis._transaction import PostgisTransaction from datacube.index.postgis._datasets import DatasetResource, DSID # type: ignore from datacube.index.postgis._metadata_types import MetadataTypeResource from datacube.index.postgis._products import ProductResource @@ -17,24 +18,6 @@ _LOG = logging.getLogger(__name__) -class PostgisTransaction(AbstractTransaction): - def __init__(self, db: PostGisDb, idx_id: str) -> None: - super().__init__(idx_id) - self._db = db - - def _new_connection(self) -> Any: - return self._db.begin() - - def _commit(self) -> None: - self._connection.commit() - - def _rollback(self) -> None: - self._connection.rollback() - - def _release_connection(self) -> None: - self._connection.close() - - class Index(AbstractIndex): """ Access to the datacube index. 
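
For orientation, the call sequence that the abstract transaction machinery drives through the PostgisTransaction methods added above is sketched here. It is a reading aid only: the index id string and the standalone construction of PostgisTransaction are illustrative, since in practice the object is obtained from index.transaction().

    from datacube.index.postgis._transaction import PostgisTransaction

    def sketch_lifecycle(db):
        # `db` is assumed to be an initialised PostGisDb; "some-index-id" is a placeholder.
        txn = PostgisTransaction(db, "some-index-id")
        txn.begin()        # _new_connection(): borrow a connection, execute BEGIN, wrap it
                           # in a PostgisDbAPI, and stash the transaction in thread-local storage
        try:
            # Resource methods called in this thread now reach the same PostgisDbAPI
            # via _db_connection().
            txn.commit()   # _commit() executes COMMIT, then _release_connection() closes
                           # the borrowed connection and the thread-local entry is purged
        except Exception:
            txn.rollback() # same cleanup, but _rollback() executes ROLLBACK first
            raise
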
diff --git a/datacube/index/postgres/_datasets.py b/datacube/index/postgres/_datasets.py index 53807bfa4d..dbb198dfa1 100755 --- a/datacube/index/postgres/_datasets.py +++ b/datacube/index/postgres/_datasets.py @@ -56,7 +56,7 @@ def get(self, id_: Union[str, UUID], include_sources=False): if isinstance(id_, str): id_ = UUID(id_) - with self.db_connection() as connection: + with self._db_connection() as connection: if not include_sources: dataset = connection.get_dataset(id_) return self._make(dataset, full_info=True) if dataset else None @@ -85,7 +85,7 @@ def to_uuid(x): ids = [to_uuid(i) for i in ids] - with self.db_connection() as connection: + with self._db_connection() as connection: rows = connection.get_datasets(ids) return [self._make(r, full_info=True) for r in rows] @@ -98,7 +98,7 @@ def get_derived(self, id_): """ if not isinstance(id_, UUID): id_ = UUID(id_) - with self.db_connection() as connection: + with self._db_connection() as connection: return [ self._make(result, full_info=True) for result in connection.get_derived_datasets(id_) @@ -111,7 +111,7 @@ def has(self, id_): :param typing.Union[UUID, str] id_: dataset id :rtype: bool """ - with self.db_connection() as connection: + with self._db_connection() as connection: return connection.contains_dataset(id_) def bulk_has(self, ids_): @@ -124,7 +124,7 @@ def bulk_has(self, ids_): :rtype: [bool] """ - with self.db_connection() as connection: + with self._db_connection() as connection: existing = set(connection.datasets_intersection(ids_)) return [x in existing for x in @@ -184,7 +184,7 @@ def process_bunch(dss, main_ds, transaction): dss = [dataset] - with self.db_connection(transaction=True) as transaction: + with self._db_connection(transaction=True) as transaction: process_bunch(dss, dataset, transaction) return dataset @@ -209,7 +209,7 @@ def load_field(f: Union[str, fields.Field]) -> fields.Field: expressions = [product.metadata_type.dataset_fields.get('product') == product.name] - with self.db_connection() as connection: + with self._db_connection() as connection: for record in connection.get_duplicates(group_fields, expressions): dataset_ids = set(record[0]) grouped_fields = tuple(record[1:]) @@ -277,7 +277,7 @@ def update(self, dataset: Dataset, updates_allowed=None): _LOG.info("Updating dataset %s", dataset.id) product = self.types.get_by_name(dataset.type.name) - with self.db_connection(transaction=True) as transaction: + with self._db_connection(transaction=True) as transaction: if not transaction.update_dataset(dataset.metadata_doc_without_lineage(), dataset.id, product.id): raise ValueError("Failed to update dataset %s..." 
% dataset.id) @@ -296,7 +296,7 @@ def insert_one(uri, transaction): # front of a stack for uri in new_uris[::-1]: if transaction is None: - with self.db_connection(transaction=True) as tr: + with self._db_connection(transaction=True) as tr: insert_one(uri, tr) else: insert_one(uri, transaction) @@ -307,7 +307,7 @@ def archive(self, ids): :param Iterable[UUID] ids: list of dataset ids to archive """ - with self.db_connection(transaction=True) as transaction: + with self._db_connection(transaction=True) as transaction: for id_ in ids: transaction.archive_dataset(id_) @@ -317,7 +317,7 @@ def restore(self, ids): :param Iterable[UUID] ids: list of dataset ids to restore """ - with self.db_connection(transaction=True) as transaction: + with self._db_connection(transaction=True) as transaction: for id_ in ids: transaction.restore_dataset(id_) @@ -327,7 +327,7 @@ def purge(self, ids: Iterable[DSID]): :param ids: iterable of dataset ids to purge """ - with self.db_connection(transaction=True) as transaction: + with self._db_connection(transaction=True) as transaction: for id_ in ids: transaction.delete_dataset(id_) @@ -341,7 +341,7 @@ def get_all_dataset_ids(self, archived: bool): :param archived: :rtype: list[UUID] """ - with self.db_connection(transaction=True) as transaction: + with self._db_connection(transaction=True) as transaction: return [dsid[0] for dsid in transaction.all_dataset_ids(archived)] def get_field_names(self, product_name=None): @@ -368,7 +368,7 @@ def get_locations(self, id_): :param typing.Union[UUID, str] id_: dataset id :rtype: list[str] """ - with self.db_connection() as connection: + with self._db_connection() as connection: return connection.get_locations(id_) def get_archived_locations(self, id_): @@ -378,7 +378,7 @@ def get_archived_locations(self, id_): :param typing.Union[UUID, str] id_: dataset id :rtype: list[str] """ - with self.db_connection() as connection: + with self._db_connection() as connection: return [uri for uri, archived_dt in connection.get_archived_locations(id_)] def get_archived_location_times(self, id_): @@ -388,7 +388,7 @@ def get_archived_location_times(self, id_): :param typing.Union[UUID, str] id_: dataset id :rtype: List[Tuple[str, datetime.datetime]] """ - with self.db_connection() as connection: + with self._db_connection() as connection: return list(connection.get_archived_locations(id_)) def add_location(self, id_, uri): @@ -403,7 +403,7 @@ def add_location(self, id_, uri): warnings.warn("Cannot add empty uri. (dataset %s)" % id_) return False - with self.db_connection() as connection: + with self._db_connection() as connection: return connection.insert_dataset_location(id_, uri) def get_datasets_for_location(self, uri, mode=None): @@ -414,7 +414,7 @@ def get_datasets_for_location(self, uri, mode=None): :param str mode: 'exact', 'prefix' or None (to guess) :return: """ - with self.db_connection() as connection: + with self._db_connection() as connection: return (self._make(row) for row in connection.get_datasets_for_location(uri, mode=mode)) def remove_location(self, id_, uri): @@ -425,7 +425,7 @@ def remove_location(self, id_, uri): :param str uri: fully qualified uri :returns bool: Was one removed? 
""" - with self.db_connection() as connection: + with self._db_connection() as connection: was_removed = connection.remove_location(id_, uri) return was_removed @@ -437,7 +437,7 @@ def archive_location(self, id_, uri): :param str uri: fully qualified uri :return bool: location was able to be archived """ - with self.db_connection() as connection: + with self._db_connection() as connection: was_archived = connection.archive_location(id_, uri) return was_archived @@ -449,7 +449,7 @@ def restore_location(self, id_, uri): :param str uri: fully qualified uri :return bool: location was able to be restored """ - with self.db_connection() as connection: + with self._db_connection() as connection: was_restored = connection.restore_location(id_, uri) return was_restored @@ -490,7 +490,7 @@ def search_by_metadata(self, metadata): :param dict metadata: :rtype: list[Dataset] """ - with self.db_connection() as connection: + with self._db_connection() as connection: for dataset in self._make_many(connection.search_datasets_by_metadata(metadata)): yield dataset @@ -646,7 +646,7 @@ def _do_search_by_product(self, query, return_fields=False, select_field_names=N else: select_fields = tuple(dataset_fields[field_name] for field_name in select_field_names) - with self.db_connection() as connection: + with self._db_connection() as connection: yield (product, connection.search_datasets( query_exprs, @@ -662,7 +662,7 @@ def _do_count_by_product(self, query): for q, product in product_queries: dataset_fields = product.metadata_type.dataset_fields query_exprs = tuple(fields.to_expressions(dataset_fields.get, **q)) - with self.db_connection() as connection: + with self._db_connection() as connection: count = connection.count_datasets(query_exprs) if count > 0: yield product, count @@ -687,7 +687,7 @@ def _do_time_count(self, period, query, ensure_single=False): for q, product in product_queries: dataset_fields = product.metadata_type.dataset_fields query_exprs = tuple(fields.to_expressions(dataset_fields.get, **q)) - with self.db_connection() as connection: + with self._db_connection() as connection: yield product, list(connection.count_datasets_through_time( start, end, @@ -732,7 +732,7 @@ def get_product_time_bounds(self, product: str): offset=max_offset, selection='greatest') - with self.db_connection() as connection: + with self._db_connection() as connection: result = connection.execute( select( [func.min(time_min.alchemy_expression), func.max(time_max.alchemy_expression)] @@ -786,7 +786,7 @@ class DatasetLight(result_type, DatasetSpatialMixin): class DatasetLight(result_type): # type: ignore __slots__ = () - with self.db_connection() as connection: + with self._db_connection() as connection: results = connection.search_unique_datasets( query_exprs, select_fields=select_fields, diff --git a/datacube/index/postgres/_metadata_types.py b/datacube/index/postgres/_metadata_types.py index f0fdc680f9..feef1935fe 100644 --- a/datacube/index/postgres/_metadata_types.py +++ b/datacube/index/postgres/_metadata_types.py @@ -70,7 +70,7 @@ def add(self, metadata_type, allow_table_lock=False): 'Metadata Type {}'.format(metadata_type.name) ) else: - with self.db_connection(transaction=allow_table_lock) as connection: + with self._db_connection(transaction=allow_table_lock) as connection: connection.insert_metadata_type( name=metadata_type.name, definition=metadata_type.definition, @@ -144,7 +144,7 @@ def update(self, metadata_type: MetadataType, allow_unsafe_updates=False, allow_ _LOG.info("Updating metadata type %s", 
metadata_type.name) - with self.db_connection(transaction=allow_table_lock) as connection: + with self._db_connection(transaction=allow_table_lock) as connection: connection.update_metadata_type( name=metadata_type.name, definition=metadata_type.definition, @@ -170,7 +170,7 @@ def update_document(self, definition, allow_unsafe_updates=False): # This is memoized in the constructor # pylint: disable=method-hidden def get_unsafe(self, id_): # type: ignore - with self.db_connection() as connection: + with self._db_connection() as connection: record = connection.get_metadata_type(id_) if record is None: raise KeyError('%s is not a valid MetadataType id') @@ -179,7 +179,7 @@ def get_unsafe(self, id_): # type: ignore # This is memoized in the constructor # pylint: disable=method-hidden def get_by_name_unsafe(self, name): # type: ignore - with self.db_connection() as connection: + with self._db_connection() as connection: record = connection.get_metadata_type_by_name(name) if not record: raise KeyError('%s is not a valid MetadataType name' % name) @@ -195,7 +195,7 @@ def check_field_indexes(self, allow_table_lock=False, If false, creation will be slightly slower and cannot be done in a transaction. """ - with self.db_connection(transaction=allow_table_lock) as connection: + with self._db_connection(transaction=allow_table_lock) as connection: connection.check_dynamic_fields( concurrently=not allow_table_lock, rebuild_indexes=rebuild_indexes, @@ -208,7 +208,7 @@ def get_all(self): :rtype: iter[datacube.model.MetadataType] """ - with self.db_connection() as connection: + with self._db_connection() as connection: return self._make_many(connection.get_all_metadata_types()) def _make_many(self, query_rows): diff --git a/datacube/index/postgres/_products.py b/datacube/index/postgres/_products.py index ca55df37fc..d0dff2be40 100644 --- a/datacube/index/postgres/_products.py +++ b/datacube/index/postgres/_products.py @@ -78,7 +78,7 @@ def add(self, product, allow_table_lock=False): _LOG.warning('Adding metadata_type "%s" as it doesn\'t exist.', product.metadata_type.name) metadata_type = self.metadata_type_resource.add(product.metadata_type, allow_table_lock=allow_table_lock) - with self.db_connection(transaction=allow_table_lock) as connection: + with self._db_connection(transaction=allow_table_lock) as connection: connection.insert_product( name=product.name, metadata=product.metadata_doc, @@ -190,7 +190,7 @@ def update(self, product: DatasetType, allow_unsafe_updates=False, allow_table_l metadata_type = cast(MetadataType, self.metadata_type_resource.get_by_name(product.metadata_type.name)) # Given we cannot change metadata type because of the check above, and this is an # update method, the metadata type is guaranteed to already exist. 
- with self.db_connection(transaction=allow_table_lock) as conn: + with self._db_connection(transaction=allow_table_lock) as conn: conn.update_product( name=product.name, metadata=product.metadata_doc, @@ -228,7 +228,7 @@ def update_document(self, definition, allow_unsafe_updates=False, allow_table_lo # This is memoized in the constructor # pylint: disable=method-hidden def get_unsafe(self, id_): # type: ignore - with self.db_connection() as connection: + with self._db_connection() as connection: result = connection.get_product(id_) if not result: raise KeyError('"%s" is not a valid Product id' % id_) @@ -237,7 +237,7 @@ def get_unsafe(self, id_): # type: ignore # This is memoized in the constructor # pylint: disable=method-hidden def get_by_name_unsafe(self, name): # type: ignore - with self.db_connection() as connection: + with self._db_connection() as connection: result = connection.get_product_by_name(name) if not result: raise KeyError('"%s" is not a valid Product name' % name) @@ -309,7 +309,7 @@ def get_all(self) -> Iterable[DatasetType]: """ Retrieve all Products """ - with self.db_connection() as connection: + with self._db_connection() as connection: return (self._make(record) for record in connection.get_all_products()) def _make_many(self, query_rows): diff --git a/datacube/index/postgres/_transaction.py b/datacube/index/postgres/_transaction.py index a8d6bcfc10..e6afe30072 100644 --- a/datacube/index/postgres/_transaction.py +++ b/datacube/index/postgres/_transaction.py @@ -39,7 +39,7 @@ def _release_connection(self) -> None: class IndexResourceAddIn: @contextmanager - def db_connection(self, transaction: bool = False) -> PostgresDbAPI: + def _db_connection(self, transaction: bool = False) -> PostgresDbAPI: """ Context manager representing a database connection. diff --git a/datacube/index/postgres/_users.py b/datacube/index/postgres/_users.py index 073f1d5ced..ae6ae26c9f 100644 --- a/datacube/index/postgres/_users.py +++ b/datacube/index/postgres/_users.py @@ -20,7 +20,7 @@ def grant_role(self, role: str, *usernames: str) -> None: """ Grant a role to users """ - with self.db_connection() as connection: + with self._db_connection() as connection: connection.grant_role(role, usernames) def create_user(self, username: str, password: str, @@ -28,14 +28,14 @@ def create_user(self, username: str, password: str, """ Create a new user. 
""" - with self.db_connection() as connection: + with self._db_connection() as connection: connection.create_user(username, password, role, description=description) def delete_user(self, *usernames: str) -> None: """ Delete a user """ - with self.db_connection() as connection: + with self._db_connection() as connection: connection.drop_users(usernames) def list_users(self) -> Iterable[Tuple[str, str, Optional[str]]]: @@ -43,6 +43,6 @@ def list_users(self) -> Iterable[Tuple[str, str, Optional[str]]]: :return: list of (role, user, description) :rtype: list[(str, str, str)] """ - with self.db_connection() as connection: + with self._db_connection() as connection: for role, user, description in connection.list_users(): yield role, user, description diff --git a/integration_tests/index/test_index_data.py b/integration_tests/index/test_index_data.py index d448d8f312..0348220899 100755 --- a/integration_tests/index/test_index_data.py +++ b/integration_tests/index/test_index_data.py @@ -358,10 +358,10 @@ def test_transactions_api_manual(index, def test_transactions_api_hybrid(index, - extended_eo3_metadata_type_doc, - ls8_eo3_product, - eo3_ls8_dataset_doc, - eo3_ls8_dataset2_doc): + extended_eo3_metadata_type_doc, + ls8_eo3_product, + eo3_ls8_dataset_doc, + eo3_ls8_dataset2_doc): from datacube.index.hl import Doc2Dataset import logging _LOG = logging.getLogger(__name__) From 74781bc8149d8a8a5f35bbfc547bd8b4ece2b67f Mon Sep 17 00:00:00 2001 From: Paul Haesler Date: Wed, 21 Sep 2022 14:56:59 +1000 Subject: [PATCH 12/18] Remove redundant internal APIs and fix tests. Still a bit cluttered. --- datacube/drivers/postgis/_api.py | 3 + datacube/drivers/postgis/_connections.py | 27 +------- datacube/drivers/postgres/_api.py | 3 + datacube/drivers/postgres/_connections.py | 27 +------- datacube/index/postgis/_transaction.py | 12 +++- datacube/index/postgis/index.py | 2 +- datacube/index/postgres/_transaction.py | 12 +++- integration_tests/index/test_config_docs.py | 2 +- integration_tests/index/test_index_data.py | 62 ++++++------------- integration_tests/index/test_search_legacy.py | 14 ++--- .../index/test_update_columns.py | 6 +- integration_tests/test_config_tool.py | 28 ++++----- tests/index/test_api_index_dataset.py | 17 +++-- 13 files changed, 85 insertions(+), 130 deletions(-) diff --git a/datacube/drivers/postgis/_api.py b/datacube/drivers/postgis/_api.py index 49f758b95c..493231d271 100644 --- a/datacube/drivers/postgis/_api.py +++ b/datacube/drivers/postgis/_api.py @@ -182,6 +182,9 @@ def __init__(self, parentdb, connection): def in_transaction(self): return self._connection.in_transaction() + def begin(self): + self._connection.execute(text('BEGIN')) + def commit(self): self._connection.execute(text('COMMIT')) diff --git a/datacube/drivers/postgis/_connections.py b/datacube/drivers/postgis/_connections.py index 6eb0fd8c90..8b6015fc49 100755 --- a/datacube/drivers/postgis/_connections.py +++ b/datacube/drivers/postgis/_connections.py @@ -245,7 +245,7 @@ def spatial_indexes(self, refresh=False) -> Iterable[CRS]: return list(self.spindexes.keys()) @contextmanager - def connect(self): + def _connect(self): """ Borrow a connection from the pool. @@ -258,35 +258,12 @@ def connect(self): The connection can raise errors if not following this advice ("server closed the connection unexpectedly"), as some servers will aggressively close idle connections (eg. DEA's NCI servers). It also prevents the connection from being reused while borrowed. 
- """ - with self._engine.connect() as connection: - yield _api.PostgisDbAPI(self, connection) - connection.close() - - @contextmanager - def begin(self): - """ - Start a transaction. - - Returns an instance that will maintain a single connection in a transaction. - - Call commit() or rollback() to complete the transaction or use a context manager: - - with db.begin() as trans: - trans.insert_dataset(...) - - (Don't share an instance between threads) - :rtype: PostgresDBAPI + Low level context manager, user self._db_connection instead """ with self._engine.connect() as connection: - connection.execute(text('BEGIN')) try: yield _api.PostgisDbAPI(self, connection) - connection.execute(text('COMMIT')) - except Exception: # pylint: disable=broad-except - connection.execute(text('ROLLBACK')) - raise finally: connection.close() diff --git a/datacube/drivers/postgres/_api.py b/datacube/drivers/postgres/_api.py index 36c3c54408..8d02f24bd9 100644 --- a/datacube/drivers/postgres/_api.py +++ b/datacube/drivers/postgres/_api.py @@ -182,6 +182,9 @@ def __init__(self, connection): def in_transaction(self): return self._connection.in_transaction() + def begin(self): + self._connection.execute(text('BEGIN')) + def rollback(self): self._connection.execute(text('ROLLBACK')) diff --git a/datacube/drivers/postgres/_connections.py b/datacube/drivers/postgres/_connections.py index 56db60bfe7..733b6e827b 100755 --- a/datacube/drivers/postgres/_connections.py +++ b/datacube/drivers/postgres/_connections.py @@ -207,7 +207,7 @@ def init(self, with_permissions=True): return is_new @contextmanager - def connect(self): + def _connect(self): """ Borrow a connection from the pool. @@ -220,35 +220,12 @@ def connect(self): The connection can raise errors if not following this advice ("server closed the connection unexpectedly"), as some servers will aggressively close idle connections (eg. DEA's NCI servers). It also prevents the connection from being reused while borrowed. - """ - with self._engine.connect() as connection: - yield _api.PostgresDbAPI(connection) - connection.close() - - @contextmanager - def begin(self): - """ - Start a transaction. - - Returns an instance that will maintain a single connection in a transaction. - - Call commit() or rollback() to complete the transaction or use a context manager: - - with db.begin() as trans: - trans.insert_dataset(...) 
- - (Don't share an instance between threads) - :rtype: PostgresDBAPI + Low level context manager, user self._db_connection instead """ with self._engine.connect() as connection: - connection.execute(text('BEGIN')) try: yield _api.PostgresDbAPI(connection) - connection.execute(text('COMMIT')) - except Exception: # pylint: disable=broad-except - connection.execute(text('ROLLBACK')) - raise finally: connection.close() diff --git a/datacube/index/postgis/_transaction.py b/datacube/index/postgis/_transaction.py index 130fc6fe8c..7add9df37a 100644 --- a/datacube/index/postgis/_transaction.py +++ b/datacube/index/postgis/_transaction.py @@ -68,10 +68,16 @@ def _db_connection(self, transaction: bool = False) -> PostgisDbAPI: yield trans._connection elif transaction: closing = True - with self._db.begin() as conn: - yield conn + with self._db._connect() as conn: + conn.begin() + try: + yield conn + conn.commit() + except Exception: # pylint: disable=broad-except + conn.rollback() + raise else: closing = True # Autocommit behaviour: - with self._db.connect() as conn: + with self._db._connect() as conn: yield conn diff --git a/datacube/index/postgis/index.py b/datacube/index/postgis/index.py index 3c4d82aea4..1e407e2c7e 100644 --- a/datacube/index/postgis/index.py +++ b/datacube/index/postgis/index.py @@ -136,7 +136,7 @@ def update_spatial_index(self, product_names: Sequence[str] = [], dataset_ids: Sequence[DSID] = [] ) -> int: - with self._db.connect() as conn: + with self.datasets._db_connection(transaction=True) as conn: return conn.update_spindex(crses, product_names, dataset_ids) def __repr__(self): diff --git a/datacube/index/postgres/_transaction.py b/datacube/index/postgres/_transaction.py index e6afe30072..c141e06c90 100644 --- a/datacube/index/postgres/_transaction.py +++ b/datacube/index/postgres/_transaction.py @@ -68,10 +68,16 @@ def _db_connection(self, transaction: bool = False) -> PostgresDbAPI: yield trans._connection elif transaction: closing = True - with self._db.begin() as conn: - yield conn + with self._db._connect() as conn: + conn.begin() + try: + yield conn + conn.commit() + except Exception: # pylint: disable=broad-except + conn.rollback() + raise else: closing = True # Autocommit behaviour: - with self._db.connect() as conn: + with self._db._connect() as conn: yield conn diff --git a/integration_tests/index/test_config_docs.py b/integration_tests/index/test_config_docs.py index 1be00ee5bb..a3f15b1345 100644 --- a/integration_tests/index/test_config_docs.py +++ b/integration_tests/index/test_config_docs.py @@ -214,7 +214,7 @@ def _object_exists(db, index_name): schema_name = "odc" else: schema_name = "agdc" - with db.connect() as connection: + with db._connect() as connection: val = connection._connection.execute(f"SELECT to_regclass('{schema_name}.{index_name}')").scalar() return val in (index_name, f'{schema_name}.{index_name}') diff --git a/integration_tests/index/test_index_data.py b/integration_tests/index/test_index_data.py index 0348220899..1fe81cb5cd 100755 --- a/integration_tests/index/test_index_data.py +++ b/integration_tests/index/test_index_data.py @@ -71,10 +71,10 @@ } -def test_archive_datasets(index, initialised_postgres_db, local_config, default_metadata_type): +def test_archive_datasets(index, local_config, default_metadata_type): dataset_type = index.products.add_document(_pseudo_telemetry_dataset_type) - with initialised_postgres_db.begin() as transaction: - was_inserted = transaction.insert_dataset( + with index.transaction() as transaction: + 
was_inserted = transaction._connection.insert_dataset( _telemetry_dataset, _telemetry_uuid, dataset_type.id @@ -106,11 +106,11 @@ def test_archive_datasets(index, initialised_postgres_db, local_config, default_ assert not indexed_dataset.is_archived -def test_purge_datasets(index, initialised_postgres_db, local_config, default_metadata_type, clirunner): +def test_purge_datasets(index, local_config, default_metadata_type, clirunner): # Create dataset dataset_type = index.products.add_document(_pseudo_telemetry_dataset_type) - with initialised_postgres_db.begin() as transaction: - was_inserted = transaction.insert_dataset( + with index.transaction() as transaction: + was_inserted = transaction._connection.insert_dataset( _telemetry_dataset, _telemetry_uuid, dataset_type.id @@ -137,15 +137,15 @@ def test_purge_datasets(index, initialised_postgres_db, local_config, default_me assert index.datasets.get(_telemetry_uuid) is None -def test_purge_datasets_cli(index, initialised_postgres_db, local_config, default_metadata_type, clirunner): +def test_purge_datasets_cli(index, local_config, default_metadata_type, clirunner): dataset_type = index.products.add_document(_pseudo_telemetry_dataset_type) # Attempt to purge non-existent dataset should fail clirunner(['dataset', 'purge', str(_telemetry_uuid)], expect_success=False) # Create dataset - with initialised_postgres_db.begin() as transaction: - was_inserted = transaction.insert_dataset( + with index.transaction() as transaction: + was_inserted = transaction._connection.insert_dataset( _telemetry_dataset, _telemetry_uuid, dataset_type.id @@ -170,12 +170,12 @@ def test_purge_datasets_cli(index, initialised_postgres_db, local_config, defaul assert index.datasets.get(_telemetry_uuid) is None -def test_purge_all_datasets_cli(index, initialised_postgres_db, local_config, default_metadata_type, clirunner): +def test_purge_all_datasets_cli(index, local_config, default_metadata_type, clirunner): dataset_type = index.products.add_document(_pseudo_telemetry_dataset_type) # Create dataset - with initialised_postgres_db.begin() as transaction: - was_inserted = transaction.insert_dataset( + with index.transaction() as transaction: + was_inserted = transaction._connection.insert_dataset( _telemetry_dataset, _telemetry_uuid, dataset_type.id @@ -202,12 +202,12 @@ def test_purge_all_datasets_cli(index, initialised_postgres_db, local_config, de @pytest.fixture -def telemetry_dataset(index: Index, initialised_postgres_db: PostgresDb, default_metadata_type) -> Dataset: +def telemetry_dataset(index: Index, default_metadata_type) -> Dataset: dataset_type = index.products.add_document(_pseudo_telemetry_dataset_type) assert not index.datasets.has(_telemetry_uuid) - with initialised_postgres_db.begin() as transaction: - was_inserted = transaction.insert_dataset( + with index.transaction() as transaction: + was_inserted = transaction._connection.insert_dataset( _telemetry_dataset, _telemetry_uuid, dataset_type.id @@ -217,14 +217,14 @@ def telemetry_dataset(index: Index, initialised_postgres_db: PostgresDb, default return index.datasets.get(_telemetry_uuid) -def test_index_duplicate_dataset(index: Index, initialised_postgres_db: PostgresDb, +def test_index_duplicate_dataset(index: Index, local_config, default_metadata_type) -> None: dataset_type = index.products.add_document(_pseudo_telemetry_dataset_type) assert not index.datasets.has(_telemetry_uuid) - with initialised_postgres_db.begin() as transaction: - was_inserted = transaction.insert_dataset( + with index.transaction() 
as transaction: + was_inserted = transaction._connection.insert_dataset( _telemetry_dataset, _telemetry_uuid, dataset_type.id @@ -234,7 +234,7 @@ def test_index_duplicate_dataset(index: Index, initialised_postgres_db: Postgres assert index.datasets.has(_telemetry_uuid) # Insert again. - with initialised_postgres_db.connect() as connection: + with index._db._connect() as connection: was_inserted = connection.insert_dataset( _telemetry_dataset, _telemetry_uuid, @@ -273,30 +273,6 @@ def test_get_dataset(index: Index, telemetry_dataset: Dataset) -> None: 'f226a278-e422-11e6-b501-185e0f80a5c1']) == [] -def test_transactions_internal_api(index: Index, - initialised_postgres_db: PostgresDb, - local_config, - default_metadata_type) -> None: - assert not index.datasets.has(_telemetry_uuid) - - dataset_type = index.products.add_document(_pseudo_telemetry_dataset_type) - with initialised_postgres_db.begin() as transaction: - was_inserted = transaction.insert_dataset( - _telemetry_dataset, - _telemetry_uuid, - dataset_type.id - ) - assert was_inserted - assert transaction.contains_dataset(_telemetry_uuid) - # Normal DB uses a separate connection: No dataset visible yet. - assert not index.datasets.has(_telemetry_uuid) - - transaction.rollback() - - # Should have been rolled back. - assert not index.datasets.has(_telemetry_uuid) - - def test_transactions_api_ctx_mgr(index, extended_eo3_metadata_type_doc, ls8_eo3_product, diff --git a/integration_tests/index/test_search_legacy.py b/integration_tests/index/test_search_legacy.py index 92e9b82e1a..6580b0e561 100644 --- a/integration_tests/index/test_search_legacy.py +++ b/integration_tests/index/test_search_legacy.py @@ -55,9 +55,9 @@ def pseudo_ls8_type(index, ga_metadata_type): @pytest.fixture -def pseudo_ls8_dataset(index, initialised_postgres_db, pseudo_ls8_type): +def pseudo_ls8_dataset(index, pseudo_ls8_type): id_ = str(uuid.uuid4()) - with initialised_postgres_db.connect() as connection: + with index._db._connect() as connection: was_inserted = connection.insert_dataset( { 'id': id_, @@ -107,10 +107,10 @@ def pseudo_ls8_dataset(index, initialised_postgres_db, pseudo_ls8_type): @pytest.fixture -def pseudo_ls8_dataset2(index, initialised_postgres_db, pseudo_ls8_type): +def pseudo_ls8_dataset2(index, pseudo_ls8_type): # Like the previous dataset, but a day later in time. id_ = str(uuid.uuid4()) - with initialised_postgres_db.connect() as connection: + with index._db._connect() as connection: was_inserted = connection.insert_dataset( { 'id': id_, @@ -162,7 +162,6 @@ def pseudo_ls8_dataset2(index, initialised_postgres_db, pseudo_ls8_type): # Datasets 3 and 4 mirror 1 and 2 but have a different path/row. 
@pytest.fixture def pseudo_ls8_dataset3(index: Index, - initialised_postgres_db: PostgresDb, pseudo_ls8_type: Product, pseudo_ls8_dataset: Dataset) -> Dataset: # Same as 1, but a different path/row @@ -174,7 +173,7 @@ def pseudo_ls8_dataset3(index: Index, 'satellite_ref_point_end': {'x': 116, 'y': 87}, } - with initialised_postgres_db.connect() as connection: + with index._db._connect() as connection: was_inserted = connection.insert_dataset( dataset_doc, id_, @@ -189,7 +188,6 @@ def pseudo_ls8_dataset3(index: Index, @pytest.fixture def pseudo_ls8_dataset4(index: Index, - initialised_postgres_db: PostgresDb, pseudo_ls8_type: Product, pseudo_ls8_dataset2: Dataset) -> Dataset: # Same as 2, but a different path/row @@ -201,7 +199,7 @@ def pseudo_ls8_dataset4(index: Index, 'satellite_ref_point_end': {'x': 116, 'y': 87}, } - with initialised_postgres_db.connect() as connection: + with index._db._connect() as connection: was_inserted = connection.insert_dataset( dataset_doc, id_, diff --git a/integration_tests/index/test_update_columns.py b/integration_tests/index/test_update_columns.py index 9e34575ae8..e64b518246 100644 --- a/integration_tests/index/test_update_columns.py +++ b/integration_tests/index/test_update_columns.py @@ -58,7 +58,7 @@ def test_added_column(clirunner, uninitialised_postgres_db): result = clirunner(["system", "init"]) assert "Created." in result.output - with uninitialised_postgres_db.connect() as connection: + with uninitialised_postgres_db._connect() as connection: assert check_column(connection, _schema.METADATA_TYPE.name, "updated") assert not check_column(connection, _schema.METADATA_TYPE.name, "fake_column") assert check_column(connection, _schema.PRODUCT.name, "updated") @@ -81,7 +81,7 @@ def test_readd_column(clirunner, uninitialised_postgres_db): result = clirunner(["system", "init"]) assert "Created." in result.output - with uninitialised_postgres_db.connect() as connection: + with uninitialised_postgres_db._connect() as connection: # Drop all the columns for an init rerun drop_column(connection, _schema.METADATA_TYPE.name, "updated") drop_column(connection, _schema.PRODUCT.name, "updated") @@ -95,7 +95,7 @@ def test_readd_column(clirunner, uninitialised_postgres_db): result = clirunner(["system", "init"]) - with uninitialised_postgres_db.connect() as connection: + with uninitialised_postgres_db._connect() as connection: assert check_column(connection, _schema.METADATA_TYPE.name, "updated") assert check_column(connection, _schema.PRODUCT.name, "updated") assert check_column(connection, _schema.DATASET.name, "updated") diff --git a/integration_tests/test_config_tool.py b/integration_tests/test_config_tool.py index d333058d3f..16dfb6f1c0 100644 --- a/integration_tests/test_config_tool.py +++ b/integration_tests/test_config_tool.py @@ -21,11 +21,11 @@ def _dataset_type_count(db): - with db.connect() as connection: + with db._connect() as connection: return len(list(connection.get_all_products())) -def test_add_example_dataset_types(clirunner, initialised_postgres_db, default_metadata_type): +def test_add_example_dataset_types(clirunner, index, default_metadata_type): """ Add example mapping docs, to ensure they're valid and up-to-date. 
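For reference, the interim pattern used throughout these test updates borrows the driver connection directly from the index. A minimal sketch, assuming the standard `index` integration-test fixture (`count_products` is an illustrative name, not part of the patch); the raw `_connect()` call is replaced by a higher-level helper later in this series:

    # Sketch only: the interim test pattern applied in the hunks above and below.
    # `index._db` is the driver-level database object; `_connect()` yields a raw driver
    # connection exposing methods such as get_all_products() and insert_dataset().
    def count_products(index):
        with index._db._connect() as connection:
            return len(list(connection.get_all_products()))
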
@@ -33,7 +33,7 @@ def test_add_example_dataset_types(clirunner, initialised_postgres_db, default_m :type initialised_postgres_db: datacube.drivers.postgres._connections.PostgresDb """ - existing_mappings = _dataset_type_count(initialised_postgres_db) + existing_mappings = _dataset_type_count(index._db) print('{} mappings'.format(existing_mappings)) for mapping_path in EXAMPLE_DATASET_TYPE_DOCS: @@ -42,7 +42,7 @@ def test_add_example_dataset_types(clirunner, initialised_postgres_db, default_m result = clirunner(['-v', 'product', 'add', mapping_path]) assert result.exit_code == 0 - mappings_count = _dataset_type_count(initialised_postgres_db) + mappings_count = _dataset_type_count(index._db) assert mappings_count > existing_mappings, "Mapping document was not added: " + str(mapping_path) existing_mappings = mappings_count @@ -178,27 +178,27 @@ def test_db_init_rebuild(clirunner, local_config, ls5_telem_type): ) in result.output -def test_db_init(clirunner, initialised_postgres_db): - if initialised_postgres_db.driver_name == "postgis": +def test_db_init(clirunner, index): + if index._db.driver_name == "postgis": from datacube.drivers.postgis._core import drop_db, has_schema else: from datacube.drivers.postgres._core import drop_db, has_schema - with initialised_postgres_db.connect() as connection: + with index._db._connect() as connection: drop_db(connection._connection) - assert not has_schema(initialised_postgres_db._engine, connection._connection) + assert not has_schema(index._db._engine, connection._connection) # Run on an empty database. - if initialised_postgres_db.driver_name == "postgis": + if index._db.driver_name == "postgis": result = clirunner(['-E', 'experimental', 'system', 'init']) else: result = clirunner(['system', 'init']) assert 'Created.' 
in result.output - with initialised_postgres_db.connect() as connection: - assert has_schema(initialised_postgres_db._engine, connection._connection) + with index._db._connect() as connection: + assert has_schema(index._db._engine, connection._connection) def test_add_no_such_product(clirunner, initialised_postgres_db): @@ -214,13 +214,13 @@ def test_add_no_such_product(clirunner, initialised_postgres_db): # Test that names are escaped ('test_user_"invalid+_chars_{n}', None), ('test_user_invalid_desc_{n}', 'Invalid "\' chars in description')]) -def example_user(clirunner, initialised_postgres_db, request): +def example_user(clirunner, index, request): username, description = request.param username = username.format(n=random.randint(111111, 999999)) # test_roles = (user_name for role_name, user_name, desc in roles if user_name.startswith('test_')) - with initialised_postgres_db.connect() as connection: + with index._db._connect() as connection: users = (user_name for role_name, user_name, desc in connection.list_users()) if username in users: connection.drop_users([username]) @@ -230,7 +230,7 @@ def example_user(clirunner, initialised_postgres_db, request): yield username, description - with initialised_postgres_db.connect() as connection: + with index._db._connect() as connection: users = (user_name for role_name, user_name, desc in connection.list_users()) if username in users: connection.drop_users([username]) diff --git a/tests/index/test_api_index_dataset.py b/tests/index/test_api_index_dataset.py index a8ff2e8aae..ce0056ed77 100644 --- a/tests/index/test_api_index_dataset.py +++ b/tests/index/test_api_index_dataset.py @@ -166,12 +166,17 @@ def __init__(self): self.dataset_source = set() @contextmanager - def begin(self): + def _connect(self): yield self - @contextmanager - def connect(self): - yield self + def begin(self): + pass + + def commit(self): + pass + + def rollback(self): + pass def get_dataset(self, id): return self.dataset.get(id, None) @@ -208,6 +213,10 @@ def get(self, *args, **kwargs): def get_by_name(self, *args, **kwargs): return self.type + @contextmanager + def _db_connection(self, transaction=False): + yield MockDb() + class MockIndex: def __init__(self, db, product): From c30f4d049db3e98787c6ecf3e59ec316a0e0914f Mon Sep 17 00:00:00 2001 From: Paul Haesler Date: Wed, 21 Sep 2022 17:20:23 +1000 Subject: [PATCH 13/18] Refactor. 
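This refactor pulls the duplicated connection/transaction plumbing out of the per-driver `_db_connection()` helpers and into a single `_active_connection()` context manager on each Index class (see the hunks below). A rough sketch of the resulting delegation, simplified from the real implementations and using an illustrative class name:

    from contextlib import contextmanager

    class _ConnectionDelegation:
        # Illustrative stand-in for the per-driver helper that resource classes use.
        def __init__(self, index):
            self._index = index

        @contextmanager
        def _db_connection(self, transaction: bool = False):
            # All connection/transaction decisions now live on the index itself:
            # reuse an active thread transaction, open a short-lived one, or autocommit.
            with self._index._active_connection(transaction=transaction) as conn:
                yield conn

    # Typical call site in a resource class, e.g. DatasetResource.spatial_extent():
    #     with self._db_connection() as connection:
    #         return connection.spatial_extent(ids, crs)
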
--- datacube/drivers/postgis/__init__.py | 3 +- datacube/drivers/postgres/__init__.py | 3 +- datacube/index/postgis/_datasets.py | 2 +- datacube/index/postgis/_transaction.py | 22 ++---------- datacube/index/postgis/index.py | 46 +++++++++++++++++++++++-- datacube/index/postgres/_transaction.py | 22 ++---------- datacube/index/postgres/index.py | 43 ++++++++++++++++++++++- tests/index/test_api_index_dataset.py | 9 +++-- 8 files changed, 99 insertions(+), 51 deletions(-) diff --git a/datacube/drivers/postgis/__init__.py b/datacube/drivers/postgis/__init__.py index 273ac73ba3..50348f828d 100644 --- a/datacube/drivers/postgis/__init__.py +++ b/datacube/drivers/postgis/__init__.py @@ -9,5 +9,6 @@ """ from ._connections import PostGisDb +from ._api import PostgisDbAPI -__all__ = ['PostGisDb'] +__all__ = ['PostGisDb', 'PostgisDbAPI'] diff --git a/datacube/drivers/postgres/__init__.py b/datacube/drivers/postgres/__init__.py index a573d53814..6fe10be3c7 100644 --- a/datacube/drivers/postgres/__init__.py +++ b/datacube/drivers/postgres/__init__.py @@ -9,5 +9,6 @@ """ from ._connections import PostgresDb +from ._api import PostgresDbAPI -__all__ = ['PostgresDb'] +__all__ = ['PostgresDb', 'PostgresDbAPI'] diff --git a/datacube/index/postgis/_datasets.py b/datacube/index/postgis/_datasets.py index ec2b50f69d..e11646ade0 100755 --- a/datacube/index/postgis/_datasets.py +++ b/datacube/index/postgis/_datasets.py @@ -893,5 +893,5 @@ def get_custom_query_expressions(self, custom_query, custom_offsets): return custom_exprs def spatial_extent(self, ids: Iterable[DSID], crs: CRS = CRS("EPSG:4326")) -> Optional[Geometry]: - with self._db.connect() as connection: + with self._db_connection() as connection: return connection.spatial_extent(ids, crs) diff --git a/datacube/index/postgis/_transaction.py b/datacube/index/postgis/_transaction.py index 7add9df37a..956989eb47 100644 --- a/datacube/index/postgis/_transaction.py +++ b/datacube/index/postgis/_transaction.py @@ -61,23 +61,5 @@ def _db_connection(self, transaction: bool = False) -> PostgisDbAPI: :param transaction: Use a transaction if one is not already active for the thread. :return: A PostgresDbAPI object, with the specified transaction semantics. 
""" - trans = self._index.thread_transaction() - closing = False - if trans is not None: - # Use active transaction - yield trans._connection - elif transaction: - closing = True - with self._db._connect() as conn: - conn.begin() - try: - yield conn - conn.commit() - except Exception: # pylint: disable=broad-except - conn.rollback() - raise - else: - closing = True - # Autocommit behaviour: - with self._db._connect() as conn: - yield conn + with self._index._active_connection(transaction=transaction) as conn: + yield conn diff --git a/datacube/index/postgis/index.py b/datacube/index/postgis/index.py index 1e407e2c7e..75c55767b8 100644 --- a/datacube/index/postgis/index.py +++ b/datacube/index/postgis/index.py @@ -3,9 +3,10 @@ # Copyright (c) 2015-2020 ODC Contributors # SPDX-License-Identifier: Apache-2.0 import logging +from contextlib import contextmanager from typing import Any, Iterable, Sequence -from datacube.drivers.postgis import PostGisDb +from datacube.drivers.postgis import PostGisDb, PostgisDbAPI from datacube.index.postgis._transaction import PostgisTransaction from datacube.index.postgis._datasets import DatasetResource, DSID # type: ignore from datacube.index.postgis._metadata_types import MetadataTypeResource @@ -136,12 +137,53 @@ def update_spatial_index(self, product_names: Sequence[str] = [], dataset_ids: Sequence[DSID] = [] ) -> int: - with self.datasets._db_connection(transaction=True) as conn: + with self._active_connection(transaction=True) as conn: return conn.update_spindex(crses, product_names, dataset_ids) def __repr__(self): return "Index".format(self._db) + @contextmanager + def _active_connection(self, transaction: bool = False) -> PostgisDbAPI: + """ + Context manager representing a database connection. + + If there is an active transaction for this index in the current thread, the connection object from that + transaction is returned, with the active transaction remaining in control of commit and rollback. + + If there is no active transaction and the transaction argument is True, a new transactionised connection + is returned, with this context manager handling commit and rollback. + + If there is no active transaction and the transaction argument is False (the default), a new connection + is returned with autocommit semantics. + + Note that autocommit behaviour is NOT available if there is an active transaction for the index + and the active thread. + + :param transaction: Use a transaction if one is not already active for the thread. + :return: A PostgresDbAPI object, with the specified transaction semantics. + """ + trans = self.thread_transaction() + closing = False + if trans is not None: + # Use active transaction + yield trans._connection + elif transaction: + closing = True + with self._db._connect() as conn: + conn.begin() + try: + yield conn + conn.commit() + except Exception: # pylint: disable=broad-except + conn.rollback() + raise + else: + closing = True + # Autocommit behaviour: + with self._db._connect() as conn: + yield conn + class DefaultIndexDriver(AbstractIndexDriver): @staticmethod diff --git a/datacube/index/postgres/_transaction.py b/datacube/index/postgres/_transaction.py index c141e06c90..4585e96e1e 100644 --- a/datacube/index/postgres/_transaction.py +++ b/datacube/index/postgres/_transaction.py @@ -61,23 +61,5 @@ def _db_connection(self, transaction: bool = False) -> PostgresDbAPI: :param transaction: Use a transaction if one is not already active for the thread. 
:return: A PostgresDbAPI object, with the specified transaction semantics. """ - trans = self._index.thread_transaction() - closing = False - if trans is not None: - # Use active transaction - yield trans._connection - elif transaction: - closing = True - with self._db._connect() as conn: - conn.begin() - try: - yield conn - conn.commit() - except Exception: # pylint: disable=broad-except - conn.rollback() - raise - else: - closing = True - # Autocommit behaviour: - with self._db._connect() as conn: - yield conn + with self._index._active_connection(transaction=transaction) as conn: + yield conn diff --git a/datacube/index/postgres/index.py b/datacube/index/postgres/index.py index 48be6456c5..8832e1e621 100644 --- a/datacube/index/postgres/index.py +++ b/datacube/index/postgres/index.py @@ -6,7 +6,7 @@ from contextlib import contextmanager from typing import Any -from datacube.drivers.postgres import PostgresDb +from datacube.drivers.postgres import PostgresDb, PostgresDbAPI from datacube.index.postgres._transaction import PostgresTransaction from datacube.index.postgres._datasets import DatasetResource # type: ignore from datacube.index.postgres._metadata_types import MetadataTypeResource @@ -117,6 +117,47 @@ def create_spatial_index(self, crs: CRS) -> None: def __repr__(self): return "Index".format(self._db) + @contextmanager + def _active_connection(self, transaction: bool = False) -> PostgresDbAPI: + """ + Context manager representing a database connection. + + If there is an active transaction for this index in the current thread, the connection object from that + transaction is returned, with the active transaction remaining in control of commit and rollback. + + If there is no active transaction and the transaction argument is True, a new transactionised connection + is returned, with this context manager handling commit and rollback. + + If there is no active transaction and the transaction argument is False (the default), a new connection + is returned with autocommit semantics. + + Note that autocommit behaviour is NOT available if there is an active transaction for the index + and the active thread. + + :param transaction: Use a transaction if one is not already active for the thread. + :return: A PostgresDbAPI object, with the specified transaction semantics. 
+ """ + trans = self.thread_transaction() + closing = False + if trans is not None: + # Use active transaction + yield trans._connection + elif transaction: + closing = True + with self._db._connect() as conn: + conn.begin() + try: + yield conn + conn.commit() + except Exception: # pylint: disable=broad-except + conn.rollback() + raise + else: + closing = True + # Autocommit behaviour: + with self._db._connect() as conn: + yield conn + class DefaultIndexDriver(AbstractIndexDriver): aliases = ['postgres'] diff --git a/tests/index/test_api_index_dataset.py b/tests/index/test_api_index_dataset.py index ce0056ed77..93a57a84ee 100644 --- a/tests/index/test_api_index_dataset.py +++ b/tests/index/test_api_index_dataset.py @@ -155,11 +155,6 @@ def _build_dataset(doc): 'added', 'added_by', 'archived']) -class MockIndex(object): - def __init__(self, db): - self._db = db - - class MockDb(object): def __init__(self): self.dataset = {} @@ -226,6 +221,10 @@ def __init__(self, db, product): def thread_transaction(self): return None + @contextmanager + def _active_connection(self, transaction=False): + yield self._db + def test_index_dataset(): mock_db = MockDb() From 44ae2b5c2834c2a3c21cc16fa6028c3080758fe3 Mon Sep 17 00:00:00 2001 From: Paul Haesler Date: Wed, 21 Sep 2022 17:43:27 +1000 Subject: [PATCH 14/18] Cleanup tests. --- integration_tests/index/test_config_docs.py | 42 +++++++++---------- integration_tests/index/test_search_legacy.py | 10 ++--- 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/integration_tests/index/test_config_docs.py b/integration_tests/index/test_config_docs.py index a3f15b1345..328aab35bf 100644 --- a/integration_tests/index/test_config_docs.py +++ b/integration_tests/index/test_config_docs.py @@ -54,20 +54,20 @@ @pytest.mark.parametrize('datacube_env_name', ('datacube', )) -def test_metadata_indexes_views_exist(initialised_postgres_db, default_metadata_type): +def test_metadata_indexes_views_exist(index, default_metadata_type): """ :type initialised_postgres_db: datacube.drivers.postgres._connections.PostgresDb :type default_metadata_type: datacube.model.MetadataType """ # Metadata indexes should no longer exist. - assert not _object_exists(initialised_postgres_db, 'dix_eo_platform') + assert not _object_exists(index, 'dix_eo_platform') # Ensure view was created (following naming conventions) - assert _object_exists(initialised_postgres_db, 'dv_eo_dataset') + assert _object_exists(index, 'dv_eo_dataset') @pytest.mark.parametrize('datacube_env_name', ('datacube', )) -def test_dataset_indexes_views_exist(initialised_postgres_db, ls5_telem_type): +def test_dataset_indexes_views_exist(index, ls5_telem_type): """ :type initialised_postgres_db: datacube.drivers.postgres._connections.PostgresDb :type ls5_telem_type: datacube.model.DatasetType @@ -75,30 +75,30 @@ def test_dataset_indexes_views_exist(initialised_postgres_db, ls5_telem_type): assert ls5_telem_type.name == 'ls5_telem_test' # Ensure field indexes were created for the dataset type (following the naming conventions): - assert _object_exists(initialised_postgres_db, "dix_ls5_telem_test_orbit") + assert _object_exists(index, "dix_ls5_telem_test_orbit") # Ensure it does not create a 'platform' index, because that's a fixed field # (ie. 
identical in every dataset of the type) - assert not _object_exists(initialised_postgres_db, "dix_ls5_telem_test_platform") + assert not _object_exists(index, "dix_ls5_telem_test_platform") # Ensure view was created (following naming conventions) - assert _object_exists(initialised_postgres_db, 'dv_ls5_telem_test_dataset') + assert _object_exists(index, 'dv_ls5_telem_test_dataset') # Ensure view was created (following naming conventions) - assert not _object_exists(initialised_postgres_db, + assert not _object_exists(index, 'dix_ls5_telem_test_gsi'), "indexed=false field gsi shouldn't have an index" @pytest.mark.parametrize('datacube_env_name', ('datacube', )) -def test_dataset_composite_indexes_exist(initialised_postgres_db, ls5_telem_type): +def test_dataset_composite_indexes_exist(index, ls5_telem_type): # This type has fields named lat/lon/time, so composite indexes should now exist for them: # (following the naming conventions) - assert _object_exists(initialised_postgres_db, "dix_ls5_telem_test_sat_path_sat_row_time") + assert _object_exists(index, "dix_ls5_telem_test_sat_path_sat_row_time") # But no individual field indexes for these - assert not _object_exists(initialised_postgres_db, "dix_ls5_telem_test_sat_path") - assert not _object_exists(initialised_postgres_db, "dix_ls5_telem_test_sat_row") - assert not _object_exists(initialised_postgres_db, "dix_ls5_telem_test_time") + assert not _object_exists(index, "dix_ls5_telem_test_sat_path") + assert not _object_exists(index, "dix_ls5_telem_test_sat_row") + assert not _object_exists(index, "dix_ls5_telem_test_time") @pytest.mark.parametrize('datacube_env_name', ('datacube', )) @@ -209,12 +209,12 @@ def test_field_expression_unchanged_postgis( ) -def _object_exists(db, index_name): - if db.driver_name == "postgis": +def _object_exists(index, index_name): + if index._db.driver_name == "postgis": schema_name = "odc" else: schema_name = "agdc" - with db._connect() as connection: + with index._active_connection() as connection: val = connection._connection.execute(f"SELECT to_regclass('{schema_name}.{index_name}')").scalar() return val in (index_name, f'{schema_name}.{index_name}') @@ -337,11 +337,11 @@ def test_update_dataset_type(index, ls5_telem_type, ls5_telem_doc, ga_metadata_t index.products.update_document(full_doc) # Remove fixed field, forcing a new index to be created (as datasets can now differ for the field). 
- assert not _object_exists(index._db, 'dix_ls5_telem_test_product_type') + assert not _object_exists(index, 'dix_ls5_telem_test_product_type') del ls5_telem_doc['metadata']['product_type'] index.products.update_document(ls5_telem_doc) # Ensure was updated - assert _object_exists(index._db, 'dix_ls5_telem_test_product_type') + assert _object_exists(index, 'dix_ls5_telem_test_product_type') updated_type = index.products.get_by_name(ls5_telem_type.name) assert updated_type.definition['metadata'] == ls5_telem_doc['metadata'] @@ -548,7 +548,7 @@ def test_filter_types_by_search(index, ls5_telem_type): @pytest.mark.parametrize('datacube_env_name', ('datacube', )) -def test_update_metadata_type_doc(initialised_postgres_db, index, ls5_telem_type): +def test_update_metadata_type_doc(index, ls5_telem_type): type_doc = copy.deepcopy(ls5_telem_type.metadata_type.definition) type_doc['dataset']['search_fields']['test_indexed'] = { 'description': 'indexed test field', @@ -563,5 +563,5 @@ def test_update_metadata_type_doc(initialised_postgres_db, index, ls5_telem_type index.metadata_types.update_document(type_doc) assert ls5_telem_type.name == 'ls5_telem_test' - assert _object_exists(initialised_postgres_db, "dix_ls5_telem_test_test_indexed") - assert not _object_exists(initialised_postgres_db, "dix_ls5_telem_test_test_not_indexed") + assert _object_exists(index, "dix_ls5_telem_test_test_indexed") + assert not _object_exists(index, "dix_ls5_telem_test_test_not_indexed") diff --git a/integration_tests/index/test_search_legacy.py b/integration_tests/index/test_search_legacy.py index 6580b0e561..07e4641f29 100644 --- a/integration_tests/index/test_search_legacy.py +++ b/integration_tests/index/test_search_legacy.py @@ -57,7 +57,7 @@ def pseudo_ls8_type(index, ga_metadata_type): @pytest.fixture def pseudo_ls8_dataset(index, pseudo_ls8_type): id_ = str(uuid.uuid4()) - with index._db._connect() as connection: + with index._active_connection() as connection: was_inserted = connection.insert_dataset( { 'id': id_, @@ -110,7 +110,7 @@ def pseudo_ls8_dataset(index, pseudo_ls8_type): def pseudo_ls8_dataset2(index, pseudo_ls8_type): # Like the previous dataset, but a day later in time. id_ = str(uuid.uuid4()) - with index._db._connect() as connection: + with index._active_connection() as connection: was_inserted = connection.insert_dataset( { 'id': id_, @@ -173,7 +173,7 @@ def pseudo_ls8_dataset3(index: Index, 'satellite_ref_point_end': {'x': 116, 'y': 87}, } - with index._db._connect() as connection: + with index._active_connection() as connection: was_inserted = connection.insert_dataset( dataset_doc, id_, @@ -199,7 +199,7 @@ def pseudo_ls8_dataset4(index: Index, 'satellite_ref_point_end': {'x': 116, 'y': 87}, } - with index._db._connect() as connection: + with index._active_connection() as connection: was_inserted = connection.insert_dataset( dataset_doc, id_, @@ -853,7 +853,7 @@ def test_cli_info(index: Index, assert yaml_docs[1]['id'] == str(pseudo_ls8_dataset2.id) -def test_cli_missing_info(clirunner, initialised_postgres_db): +def test_cli_missing_info(clirunner, index): id_ = str(uuid.uuid4()) result = clirunner( [ From 33865f02f29dbfdd9378fdd38d140fa9b551f110 Mon Sep 17 00:00:00 2001 From: Paul Haesler Date: Wed, 21 Sep 2022 17:51:59 +1000 Subject: [PATCH 15/18] Further cleanup of tests. 
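This cleanup removes the `initialised_postgres_db` fixture altogether: tests take the `index` fixture and obtain any low-level connection through the index. In outline, a sketch of the pattern applied across the hunks below, assuming the standard integration-test fixtures (`products_in_index` is an illustrative name):

    # Old pattern (the fixture deleted from conftest.py below):
    #     with initialised_postgres_db.connect() as connection:
    #         ...
    #
    # New pattern -- ask the index for a connection, which also reuses an active
    # thread-local transaction if one is in progress:
    def products_in_index(index):
        with index._active_connection() as connection:
            return list(connection.get_all_products())
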
--- integration_tests/conftest.py | 9 --------- integration_tests/test_config_tool.py | 25 ++++++++++--------------- 2 files changed, 10 insertions(+), 24 deletions(-) diff --git a/integration_tests/conftest.py b/integration_tests/conftest.py index e59e76f80a..1a834a0984 100644 --- a/integration_tests/conftest.py +++ b/integration_tests/conftest.py @@ -376,15 +376,6 @@ def index_empty(local_config, uninitialised_postgres_db: PostgresDb): del index -@pytest.fixture -def initialised_postgres_db(index): - """ - Return a connection to an PostgreSQL (or Postgis) database, initialised with the default schema - and tables. - """ - return index._db - - def remove_postgres_dynamic_indexes(): """ Clear any dynamically created postgresql indexes from the schema. diff --git a/integration_tests/test_config_tool.py b/integration_tests/test_config_tool.py index 16dfb6f1c0..ef6cc5c53c 100644 --- a/integration_tests/test_config_tool.py +++ b/integration_tests/test_config_tool.py @@ -20,8 +20,8 @@ INVALID_MAPPING_DOCS = map(str, Path(__file__).parent.parent.joinpath('docs').glob('*')) -def _dataset_type_count(db): - with db._connect() as connection: +def _dataset_type_count(index): + with index._active_connection() as connection: return len(list(connection.get_all_products())) @@ -30,10 +30,8 @@ def test_add_example_dataset_types(clirunner, index, default_metadata_type): Add example mapping docs, to ensure they're valid and up-to-date. We add them all to a single database to check for things like duplicate ids. - - :type initialised_postgres_db: datacube.drivers.postgres._connections.PostgresDb """ - existing_mappings = _dataset_type_count(index._db) + existing_mappings = _dataset_type_count(index) print('{} mappings'.format(existing_mappings)) for mapping_path in EXAMPLE_DATASET_TYPE_DOCS: @@ -42,7 +40,7 @@ def test_add_example_dataset_types(clirunner, index, default_metadata_type): result = clirunner(['-v', 'product', 'add', mapping_path]) assert result.exit_code == 0 - mappings_count = _dataset_type_count(index._db) + mappings_count = _dataset_type_count(index) assert mappings_count > existing_mappings, "Mapping document was not added: " + str(mapping_path) existing_mappings = mappings_count @@ -79,11 +77,8 @@ def test_add_example_dataset_types(clirunner, index, default_metadata_type): assert result.exit_code == 0 -def test_error_returned_on_invalid(clirunner, initialised_postgres_db): - """ - :type initialised_postgres_db: datacube.drivers.postgres._connections.PostgresDb - """ - assert _dataset_type_count(initialised_postgres_db) == 0 +def test_error_returned_on_invalid(clirunner, index): + assert _dataset_type_count(index) == 0 for mapping_path in INVALID_MAPPING_DOCS: result = clirunner( @@ -95,10 +90,10 @@ def test_error_returned_on_invalid(clirunner, initialised_postgres_db): expect_success=False ) assert result.exit_code != 0, "Success return code for invalid document." 
- assert _dataset_type_count(initialised_postgres_db) == 0, "Invalid document was added to DB" + assert _dataset_type_count(index) == 0, "Invalid document was added to DB" -def test_config_check(clirunner, initialised_postgres_db, local_config): +def test_config_check(clirunner, index, local_config): """ :type local_config: datacube.config.LocalConfig """ @@ -119,7 +114,7 @@ def test_config_check(clirunner, initialised_postgres_db, local_config): assert user_regex.match(result.output) -def test_list_users_does_not_fail(clirunner, local_config, initialised_postgres_db): +def test_list_users_does_not_fail(clirunner, local_config, index): """ :type local_config: datacube.config.LocalConfig """ @@ -201,7 +196,7 @@ def test_db_init(clirunner, index): assert has_schema(index._db._engine, connection._connection) -def test_add_no_such_product(clirunner, initialised_postgres_db): +def test_add_no_such_product(clirunner, index): result = clirunner(['dataset', 'add', '--dtype', 'no_such_product', '/tmp'], expect_success=False) assert result.exit_code != 0 assert "DEPRECATED option detected" in result.output From 0a7a755e54ab18e78efee28b38f0317334fd5552 Mon Sep 17 00:00:00 2001 From: Paul Haesler Date: Wed, 21 Sep 2022 18:07:15 +1000 Subject: [PATCH 16/18] Flake8 appeasement. --- datacube/drivers/postgis/_connections.py | 2 +- datacube/drivers/postgres/_connections.py | 2 +- datacube/index/exceptions.py | 2 +- datacube/index/postgis/_transaction.py | 3 --- datacube/index/postgis/_users.py | 5 ++++- datacube/index/postgis/index.py | 2 +- datacube/index/postgres/_transaction.py | 3 --- datacube/index/postgres/_users.py | 5 ++++- datacube/index/postgres/index.py | 1 - integration_tests/index/test_index_data.py | 7 ------- integration_tests/index/test_search_legacy.py | 1 - 11 files changed, 12 insertions(+), 21 deletions(-) diff --git a/datacube/drivers/postgis/_connections.py b/datacube/drivers/postgis/_connections.py index 8b6015fc49..bbc8745205 100755 --- a/datacube/drivers/postgis/_connections.py +++ b/datacube/drivers/postgis/_connections.py @@ -19,7 +19,7 @@ from contextlib import contextmanager from typing import Any, Callable, Iterable, Mapping, Optional, Union, Type -from sqlalchemy import event, create_engine, text +from sqlalchemy import event, create_engine from sqlalchemy.engine import Engine from sqlalchemy.engine.url import URL as EngineUrl # noqa: N811 diff --git a/datacube/drivers/postgres/_connections.py b/datacube/drivers/postgres/_connections.py index 733b6e827b..130cef55f3 100755 --- a/datacube/drivers/postgres/_connections.py +++ b/datacube/drivers/postgres/_connections.py @@ -19,7 +19,7 @@ from contextlib import contextmanager from typing import Callable, Optional, Union -from sqlalchemy import event, create_engine, text +from sqlalchemy import event, create_engine from sqlalchemy.engine import Engine from sqlalchemy.engine.url import URL as EngineUrl # noqa: N811 diff --git a/datacube/index/exceptions.py b/datacube/index/exceptions.py index 282a6a59f0..e75a4135ae 100644 --- a/datacube/index/exceptions.py +++ b/datacube/index/exceptions.py @@ -16,7 +16,7 @@ class IndexSetupError(Exception): pass -class TransactionException(Exception): +class TransactionException(Exception): # noqa: N818 def __init__(self, *args, commit=False, **kwargs): super().__init__(*args, **kwargs) self.commit = commit diff --git a/datacube/index/postgis/_transaction.py b/datacube/index/postgis/_transaction.py index 956989eb47..60994be7ec 100644 --- a/datacube/index/postgis/_transaction.py +++ 
b/datacube/index/postgis/_transaction.py @@ -3,8 +3,6 @@ # Copyright (c) 2015-2022 ODC Contributors # SPDX-License-Identifier: Apache-2.0 -import logging - from contextlib import contextmanager from sqlalchemy import text from typing import Any @@ -13,7 +11,6 @@ from datacube.drivers.postgis._api import PostgisDbAPI from datacube.index.abstract import AbstractTransaction -_LOG = logging.getLogger(__name__) class PostgisTransaction(AbstractTransaction): def __init__(self, db: PostGisDb, idx_id: str) -> None: diff --git a/datacube/index/postgis/_users.py b/datacube/index/postgis/_users.py index 1acb9a7f77..c604a676b4 100644 --- a/datacube/index/postgis/_users.py +++ b/datacube/index/postgis/_users.py @@ -9,7 +9,10 @@ class UserResource(AbstractUserResource, IndexResourceAddIn): - def __init__(self, db: PostGisDb, index: "datacube.index.postgis.index.Index") -> None: + def __init__(self, + db: PostGisDb, + index: "datacube.index.postgis.index.Index" # noqa: F821 + ) -> None: """ :type db: datacube.drivers.postgis.PostGisDb """ diff --git a/datacube/index/postgis/index.py b/datacube/index/postgis/index.py index 75c55767b8..db8366160a 100644 --- a/datacube/index/postgis/index.py +++ b/datacube/index/postgis/index.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: Apache-2.0 import logging from contextlib import contextmanager -from typing import Any, Iterable, Sequence +from typing import Iterable, Sequence from datacube.drivers.postgis import PostGisDb, PostgisDbAPI from datacube.index.postgis._transaction import PostgisTransaction diff --git a/datacube/index/postgres/_transaction.py b/datacube/index/postgres/_transaction.py index 4585e96e1e..2c9a7a17e7 100644 --- a/datacube/index/postgres/_transaction.py +++ b/datacube/index/postgres/_transaction.py @@ -3,8 +3,6 @@ # Copyright (c) 2015-2022 ODC Contributors # SPDX-License-Identifier: Apache-2.0 -import logging - from contextlib import contextmanager from sqlalchemy import text from typing import Any @@ -13,7 +11,6 @@ from datacube.drivers.postgres._api import PostgresDbAPI from datacube.index.abstract import AbstractTransaction -_LOG = logging.getLogger(__name__) class PostgresTransaction(AbstractTransaction): def __init__(self, db: PostgresDb, idx_id: str) -> None: diff --git a/datacube/index/postgres/_users.py b/datacube/index/postgres/_users.py index ae6ae26c9f..2dd4c81877 100644 --- a/datacube/index/postgres/_users.py +++ b/datacube/index/postgres/_users.py @@ -9,7 +9,10 @@ class UserResource(AbstractUserResource, IndexResourceAddIn): - def __init__(self, db: PostgresDb, index: "datacube.index.postgres.index.Index") -> None: + def __init__(self, + db: PostgresDb, + index: "datacube.index.postgres.index.Index" # noqa: F821 + ) -> None: """ :type db: datacube.drivers.postgres._connections.PostgresDb """ diff --git a/datacube/index/postgres/index.py b/datacube/index/postgres/index.py index 8832e1e621..d56016bf6b 100644 --- a/datacube/index/postgres/index.py +++ b/datacube/index/postgres/index.py @@ -4,7 +4,6 @@ # SPDX-License-Identifier: Apache-2.0 import logging from contextlib import contextmanager -from typing import Any from datacube.drivers.postgres import PostgresDb, PostgresDbAPI from datacube.index.postgres._transaction import PostgresTransaction diff --git a/integration_tests/index/test_index_data.py b/integration_tests/index/test_index_data.py index 1fe81cb5cd..a205a3836e 100755 --- a/integration_tests/index/test_index_data.py +++ b/integration_tests/index/test_index_data.py @@ -16,7 +16,6 @@ import pytest from dateutil import tz -from 
datacube.drivers.postgres import PostgresDb from datacube.index.exceptions import MissingRecordError from datacube.index import Index from datacube.model import Dataset, MetadataType @@ -279,8 +278,6 @@ def test_transactions_api_ctx_mgr(index, eo3_ls8_dataset_doc, eo3_ls8_dataset2_doc): from datacube.index.hl import Doc2Dataset - import logging - _LOG = logging.getLogger(__name__) resolver = Doc2Dataset(index, products=[ls8_eo3_product.name], verify_lineage=False) ds1, err = resolver(*eo3_ls8_dataset_doc) ds2, err = resolver(*eo3_ls8_dataset2_doc) @@ -311,8 +308,6 @@ def test_transactions_api_manual(index, eo3_ls8_dataset_doc, eo3_ls8_dataset2_doc): from datacube.index.hl import Doc2Dataset - import logging - _LOG = logging.getLogger(__name__) resolver = Doc2Dataset(index, products=[ls8_eo3_product.name], verify_lineage=False) ds1, err = resolver(*eo3_ls8_dataset_doc) ds2, err = resolver(*eo3_ls8_dataset2_doc) @@ -339,8 +334,6 @@ def test_transactions_api_hybrid(index, eo3_ls8_dataset_doc, eo3_ls8_dataset2_doc): from datacube.index.hl import Doc2Dataset - import logging - _LOG = logging.getLogger(__name__) resolver = Doc2Dataset(index, products=[ls8_eo3_product.name], verify_lineage=False) ds1, err = resolver(*eo3_ls8_dataset_doc) ds2, err = resolver(*eo3_ls8_dataset2_doc) diff --git a/integration_tests/index/test_search_legacy.py b/integration_tests/index/test_search_legacy.py index 07e4641f29..d04ee2833d 100644 --- a/integration_tests/index/test_search_legacy.py +++ b/integration_tests/index/test_search_legacy.py @@ -18,7 +18,6 @@ from psycopg2._range import NumericRange from datacube.config import LocalConfig -from datacube.drivers.postgres import PostgresDb from datacube.drivers.postgres._connections import DEFAULT_DB_USER from datacube.index import Index from datacube.model import Dataset From 59439cea499a88b883a14ee10ff5fe74207ab5a2 Mon Sep 17 00:00:00 2001 From: Paul Haesler Date: Fri, 23 Sep 2022 10:36:58 +1000 Subject: [PATCH 17/18] whats_new.rst --- docs/about/whats_new.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/about/whats_new.rst b/docs/about/whats_new.rst index 1f1a66626f..77a30ccc0e 100644 --- a/docs/about/whats_new.rst +++ b/docs/about/whats_new.rst @@ -37,6 +37,8 @@ v1.8.next - Implement `patch_url` argument to `dc.load()` and `dc.load_data()` to provide a way to sign dataset URIs, as is required to access some commercial archives (e.g. Microsoft Planetary Computer). API is based on the `odc-stac` implementation. Only works for direct loading. More work required for deferred (i.e. Dask) loading. (:pull: `1317`) +- Implement public-facing index-driver-independent API for managing database transactions, as per Enhancement Proposal + EP07 (:pull: `1318`) v1.8.7 (7 June 2022) From de53382ebaaaa18f7e2473a287d8f661e54fc1c5 Mon Sep 17 00:00:00 2001 From: Paul Haesler Date: Fri, 23 Sep 2022 11:23:55 +1000 Subject: [PATCH 18/18] Comment corrections. --- datacube/drivers/postgis/_connections.py | 2 +- datacube/drivers/postgres/_connections.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/datacube/drivers/postgis/_connections.py b/datacube/drivers/postgis/_connections.py index bbc8745205..7abbc604ca 100755 --- a/datacube/drivers/postgis/_connections.py +++ b/datacube/drivers/postgis/_connections.py @@ -259,7 +259,7 @@ def _connect(self): as some servers will aggressively close idle connections (eg. DEA's NCI servers). It also prevents the connection from being reused while borrowed. 
- Low level context manager, user self._db_connection instead + Low level context manager, use ._db_connection instead """ with self._engine.connect() as connection: try: diff --git a/datacube/drivers/postgres/_connections.py b/datacube/drivers/postgres/_connections.py index 130cef55f3..3fc5ed7de0 100755 --- a/datacube/drivers/postgres/_connections.py +++ b/datacube/drivers/postgres/_connections.py @@ -221,7 +221,7 @@ def _connect(self): as some servers will aggressively close idle connections (eg. DEA's NCI servers). It also prevents the connection from being reused while borrowed. - Low level context manager, user self._db_connection instead + Low level context manager, use ._db_connection instead """ with self._engine.connect() as connection: try: