IAM authentication for Postgres (#1168)
* Utility function to obtain IAM auth token.

* Hook IAM auth into engine instantiation.

* Return last token if not getting a new one.

* Cleanup coding style.

* Cleanup coding style.

* Cleanup coding style - how deep an indent does it want?

* Pass region name in test.

* s/region/region_name/  oops.

* Pass region name when creating boto3 client.

* Fix obtain_new_iam_auth_token() function signature.

* Move iam config into config layer.

* Make options and env var names more consistent.

* Test environment parser.

* Removed hard-to-test and unnecessary exception conversion.

* Add documentation.

* Update whats_new.rst

Co-authored-by: phaesler <paul.haesler@data61.csiro.au>
SpacemanPaul and phaesler committed Aug 5, 2021
1 parent 5076f09 commit 76eb350
Showing 8 changed files with 159 additions and 18 deletions.
28 changes: 20 additions & 8 deletions datacube/config.py
@@ -194,20 +194,32 @@ def split2(s: str, separator: str) -> Tuple[str, str]:

def parse_env_params() -> Dict[str, str]:
"""
- Read DATACUBE_IAM_* environment variables.
- Extract parameters from DATACUBE_DB_URL if present
- Else look for DB_HOSTNAME, DB_USERNAME, DB_PASSWORD, DB_DATABASE
- Return {} otherwise
"""

# Handle environment vars that cannot fit in the DB URL
non_url_params = {}
iam_auth = os.environ.get('DATACUBE_IAM_AUTHENTICATION')
if iam_auth is not None and iam_auth.lower() in ['y', 'yes']:
non_url_params["iam_authentication"] = True
iam_auth_timeout = os.environ.get('DATACUBE_IAM_TIMEOUT')
if iam_auth_timeout:
non_url_params["iam_timeout"] = int(iam_auth_timeout)

# Handle environment vars that may fit in the DB URL
db_url = os.environ.get('DATACUBE_DB_URL', None)
if db_url is not None:
return parse_connect_url(db_url)

params = {k: os.environ.get('DB_{}'.format(k.upper()), None)
for k in DB_KEYS}
return {k: v
for k, v in params.items()
if v is not None and v != ""}
params = parse_connect_url(db_url)
else:
raw_params = {k: os.environ.get('DB_{}'.format(k.upper()), None)
for k in DB_KEYS}
params = {k: v
for k, v in raw_params.items()
if v is not None and v != ""}
params.update(non_url_params)
return params


def _cfg_from_env_opts(opts: Dict[str, str],
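
As a hedged illustration of the new environment handling in parse_env_params (hypothetical values; assumes DATACUBE_DB_URL is not set):

    import os

    from datacube.config import parse_env_params

    # Hypothetical values for illustration only.
    os.environ["DATACUBE_IAM_AUTHENTICATION"] = "yes"
    os.environ["DATACUBE_IAM_TIMEOUT"] = "900"
    os.environ["DB_HOSTNAME"] = "your.rds.server.name"
    os.environ["DB_DATABASE"] = "datacube"

    # The IAM options are merged with the DB_* values:
    # {'hostname': 'your.rds.server.name', 'database': 'datacube',
    #  'iam_authentication': True, 'iam_timeout': 900}
    print(parse_env_params())
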
49 changes: 42 additions & 7 deletions datacube/drivers/postgres/_connections.py
@@ -17,14 +17,18 @@
import os
import re
from contextlib import contextmanager
from typing import Optional
from time import clock_gettime, CLOCK_REALTIME
from typing import Callable, Optional, Union

from sqlalchemy import create_engine, text
from sqlalchemy import event, create_engine, text
from sqlalchemy.engine import Engine
from sqlalchemy.engine.url import URL as EngineUrl

import datacube
from datacube.index.exceptions import IndexSetupError
from datacube.utils import jsonify_document
from datacube.utils.aws import obtain_new_iam_auth_token

from . import _api
from . import _core

@@ -40,6 +44,8 @@
# No default on Windows and some other systems
DEFAULT_DB_USER = None
DEFAULT_DB_PORT = 5432
DEFAULT_IAM_AUTH = False
DEFAULT_IAM_TIMEOUT = 600


class PostgresDb(object):
@@ -74,12 +80,18 @@ def from_config(cls, config, application_name=None, validate_connection=True):
config.get('db_port', DEFAULT_DB_PORT),
application_name=app_name,
validate=validate_connection,
pool_timeout=int(config.get('db_connection_timeout', 60))
iam_rds_auth=bool(config.get("db_iam_authentication", DEFAULT_IAM_AUTH)),
iam_rds_timeout=int(config.get("db_iam_timeout", DEFAULT_IAM_TIMEOUT)),
pool_timeout=int(config.get('db_connection_timeout', 60)),
# pass config?
)

@classmethod
def create(cls, hostname, database, username=None, password=None, port=None,
application_name=None, validate=True, pool_timeout=60):
application_name=None, validate=True,
iam_rds_auth=False, iam_rds_timeout=600,
# pass config?
pool_timeout=60):
mk_url = getattr(EngineUrl, 'create', EngineUrl)
engine = cls._create_engine(
mk_url(
@@ -88,6 +100,8 @@ def create(cls, hostname, database, username=None, password=None, port=None,
username=username, password=password,
),
application_name=application_name,
iam_rds_auth=iam_rds_auth,
iam_rds_timeout=iam_rds_timeout,
pool_timeout=pool_timeout)
if validate:
if not _core.database_exists(engine):
@@ -104,16 +118,15 @@ def create(cls, hostname, database, username=None, password=None, port=None,
return PostgresDb(engine)

@staticmethod
def _create_engine(url, application_name=None, pool_timeout=60):
return create_engine(
def _create_engine(url, application_name=None, iam_rds_auth=False, iam_rds_timeout=600, pool_timeout=60):
engine = create_engine(
url,
echo=False,
echo_pool=False,

# 'AUTOCOMMIT' here means READ-COMMITTED isolation level with autocommit on.
# When a transaction is needed we will do an explicit begin/commit.
isolation_level='AUTOCOMMIT',

json_serializer=_to_json,
# If a connection is idle for this many seconds, SQLAlchemy will renew it rather
# than assuming it's still open. Allows servers to close idle connections without clients
@@ -122,6 +135,11 @@ def _create_engine(url, application_name=None, pool_timeout=60):
connect_args={'application_name': application_name}
)

if iam_rds_auth:
handle_dynamic_token_authentication(engine, obtain_new_iam_auth_token, timeout=iam_rds_timeout, url=url)

return engine

@property
def url(self) -> EngineUrl:
return self._engine.url
@@ -244,6 +262,23 @@ def __repr__(self):
return "PostgresDb<engine={!r}>".format(self._engine)


def handle_dynamic_token_authentication(engine: Engine,
new_token: Callable[..., str],
timeout: Union[float, int] = 600,
**kwargs) -> None:
last_token = [None]
last_token_time = [0.0]

@event.listens_for(engine, "do_connect")
def override_new_connection(dialect, conn_rec, cargs, cparams):
# Handle IAM authentication
now = clock_gettime(CLOCK_REALTIME)
if now - last_token_time[0] > timeout:
last_token[0] = new_token(**kwargs)
last_token_time[0] = now
cparams["password"] = last_token[0]


def _to_json(o):
# Postgres <=9.5 doesn't support NaN and Infinity
fixedup = jsonify_document(o)
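
A minimal sketch of wiring up handle_dynamic_token_authentication with a custom token callback (the same pattern the new test below uses; connection details are fake and SQLAlchemy 1.4+ is assumed for URL.create):

    from sqlalchemy import create_engine
    from sqlalchemy.engine.url import URL

    from datacube.drivers.postgres._connections import handle_dynamic_token_authentication

    def fetch_token(region_name):
        # Stand-in for obtain_new_iam_auth_token(): return a short-lived credential.
        return "token-for-" + region_name

    # Fake connection details; nothing connects until engine.connect() is called.
    url = URL.create("postgresql", host="db.example.internal", port=5432,
                     database="datacube", username="odc_user")
    engine = create_engine(url)

    # On each new DBAPI connection the hook reuses the cached token, fetching a
    # fresh one only when the cached token is older than `timeout` seconds.
    handle_dynamic_token_authentication(engine, fetch_token, timeout=600,
                                        region_name="ap-southeast-2")
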
12 changes: 12 additions & 0 deletions datacube/utils/aws/__init__.py
@@ -13,6 +13,8 @@
import time
from urllib.request import urlopen
from urllib.parse import urlparse
from sqlalchemy.engine.url import URL

from typing import Optional, Dict, Tuple, Any, Union, IO
from datacube.utils.generic import thread_local_cache
from ..rio import configure_s3_access
@@ -438,3 +440,13 @@ def get_aws_settings(profile: Optional[str] = None,
aws_secret_access_key=cc.secret_key,
aws_session_token=cc.token,
requester_pays=requester_pays), creds)


def obtain_new_iam_auth_token(url: URL, region_name: str = "auto", profile_name: Optional[str] = None) -> str:
# Boto3 is not a core requirement, but ImportError is probably the right exception to throw anyway.
from boto3.session import Session as Boto3Session

session = Boto3Session(profile_name=profile_name)
client = session.client("rds", region_name=region_name)
return client.generate_db_auth_token(DBHostname=url.host, Port=url.port, DBUsername=url.username,
Region=region_name)
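
A short, hedged sketch of calling the new helper directly (requires boto3 and valid AWS credentials; the endpoint and username are placeholders):

    from sqlalchemy.engine.url import URL

    from datacube.utils.aws import obtain_new_iam_auth_token

    url = URL.create("postgresql",
                     host="mydb.abc123xyz.ap-southeast-2.rds.amazonaws.com",
                     port=5432, database="datacube", username="odc_iam_user")

    # The token is used in place of a password; RDS only accepts it for a short
    # window, hence the periodic refresh wired up in _connections.py.
    token = obtain_new_iam_auth_token(url, region_name="ap-southeast-2")
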
1 change: 1 addition & 0 deletions docs/about/whats_new.rst
@@ -21,6 +21,7 @@ v1.8.4 (???)
- Add new ``dataset_predicate`` param to ``dc.load`` and ``dc.find_datasets`` for more flexible temporal filtering (e.g. loading data for non-contiguous time ranges such as specific months or seasons over multiple years). (:pull:`1148`, :pull:`1156`)
- Fix to ``GroupBy`` to ensure output axes are correctly labelled when sorting observations using ``sort_key`` (:pull:`1157`)
- ``GroupBy`` is now its own class to allow easier custom grouping and sorting of data (:pull:`1157`)
- Add support for IAM authentication for RDS databases in AWS (:pull:`1168`)

.. _`notebook examples`: https://github.com/GeoscienceAustralia/dea-notebooks/

27 changes: 25 additions & 2 deletions docs/ops/config.rst
@@ -68,6 +68,23 @@ Example:
[staging]
db_hostname: staging.dea.ga.gov.au
## An AWS environment, using RDS ##
[aws_rds]
# Point database to an RDS server on AWS.
db_hostname: your.rds.server.name
db_username: your_rds_username
db_database: your_rds_db
# Choose an authentication option:
# 1. password authentication, as documented above
# db_password: Ungue55able$ecRet
# 2. IAM Authentication
# iam_authentication: yes
#
# Token timeout in seconds. Defaults to 600 (10 minutes)
# iam_timeout: 750

Note that the ``staging`` environment only specifies the hostname; all other
fields will use default values (database ``datacube``, current username,
@@ -98,7 +115,7 @@ It is possible to configure datacube with a single environment variable:
inside a docker image. The format of the URL is the same as used by SQLAlchemy:
``postgresql://user:password@host:port/database``. Only the ``database`` parameter
is required. Note that ``password`` is URL encoded, so it can contain special
characters.

For more information refer to the `SQLAlchemy database URLs documentation
<https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls>`_.
@@ -118,9 +135,15 @@ Examples:


It is also possible to use separate environment variables for each component of
the connection URL. The recognised environment variables are
``DB_HOSTNAME``, ``DB_PORT``, ``DB_USERNAME``, ``DB_PASSWORD`` and ``DB_DATABASE``.

AWS IAM authentication for RDS can also be activated by setting the
``DATACUBE_IAM_AUTHENTICATION`` environment variable to ``'y'`` or ``'yes'``.
The IAM token timeout can be tuned by setting the ``DATACUBE_IAM_TIMEOUT``
environment variable to a value in seconds. Default is 600 (i.e. 10 minutes).
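
As a hedged sketch (not part of the original docs), the same settings can be driven from Python by populating the environment before constructing a ``Datacube``; the values below are placeholders:

    import os
    import datacube

    # Placeholder connection details; equivalent to the iam_authentication and
    # iam_timeout options in a configuration file.
    os.environ["DB_HOSTNAME"] = "your.rds.server.name"
    os.environ["DB_USERNAME"] = "your_rds_username"
    os.environ["DB_DATABASE"] = "your_rds_db"
    os.environ["DATACUBE_IAM_AUTHENTICATION"] = "yes"
    os.environ["DATACUBE_IAM_TIMEOUT"] = "600"

    # The index connection then authenticates with freshly generated IAM tokens.
    dc = datacube.Datacube(app="iam-auth-example")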


Types of Indexes
================

8 changes: 7 additions & 1 deletion tests/test_config.py
@@ -125,7 +125,9 @@ def _clear_cfg_env(monkeypatch):
'DB_HOSTNAME',
'DB_PORT',
'DB_USERNAME',
'DB_PASSWORD'):
'DB_PASSWORD',
'DATACUBE_IAM_AUTHENTICATION',
'DATACUBE_IAM_TIMEOUT'):
monkeypatch.delenv(e, raising=False)


@@ -140,6 +142,10 @@ def check_env(**kw):
return parse_env_params()

assert check_env() == {}
assert check_env(DATACUBE_IAM_AUTHENTICATION="yes",
DATACUBE_IAM_TIMEOUT='666') == dict(
iam_authentication=True,
iam_timeout=666)
assert check_env(DATACUBE_DB_URL='postgresql:///db') == dict(
hostname='',
database='db'
35 changes: 35 additions & 0 deletions tests/test_dynamic_db_passwd.py
@@ -0,0 +1,35 @@
# This file is part of the Open Data Cube, see https://opendatacube.org for more information
#
# Copyright (c) 2015-2020 ODC Contributors
# SPDX-License-Identifier: Apache-2.0
import pytest
from sqlalchemy.exc import OperationalError
from sqlalchemy.engine.url import URL

from datacube.drivers.postgres._connections import PostgresDb, handle_dynamic_token_authentication


counter = [0]
last_base = [None]


def next_token(base):
counter[0] = counter[0] + 1
last_base[0] = base
return f"{base}{counter[0]}"


def test_dynamic_password():
url = URL.create(
'postgresql',
host="fake_host", database="fake_database", port=6543,
username="fake_username", password="fake_password"
)
engine = PostgresDb._create_engine(url)
counter[0] = 0
last_base[0] = None
handle_dynamic_token_authentication(engine, next_token, base="password")
with pytest.raises(OperationalError):
conn = engine.connect()
assert counter[0] == 1
assert last_base[0] == "password"
17 changes: 17 additions & 0 deletions tests/test_utils_aws.py
@@ -23,6 +23,7 @@
s3_fetch,
s3_head_object,
_s3_cache_key,
obtain_new_iam_auth_token,
)


@@ -255,3 +256,19 @@ def test_s3_client_cache(monkeypatch, without_aws_env):

keys = set(_s3_cache_key(**o) for o in opts)
assert len(keys) == len(opts)


def test_obtain_new_iam_token(monkeypatch, without_aws_env):
import moto
from sqlalchemy.engine.url import URL
url = URL.create(
'postgresql',
host="fakehost", database="fake_db", port=5432,
username="fakeuser", password="definitely_a_fake_password",
)

monkeypatch.setenv("AWS_ACCESS_KEY_ID", "fake-key-id")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "fake-secret")
with moto.mock_iam():
token = obtain_new_iam_auth_token(url, region_name='us-west-1')
assert isinstance(token, str)
