From d6b65d1b843c2485ed223fd946b80332276c40bd Mon Sep 17 00:00:00 2001 From: Jelte Fennema Date: Wed, 28 Jun 2023 09:30:56 +0200 Subject: [PATCH] Support replication connections through PgBouncer In session pooling mode PgBouncer is pretty much a transparent proxy, i.e. the client does normally not even need to know that PgBouncer is in the middle. This allows things like load balancing and failovers without the client needing to know about this at all. But as soon as replication connections are needed, this was not possible anymore, because PgBouncer would reject those instead of proxying them to the right server. This PR fixes that by also proxying replication connections. They are handled pretty differently from normal connections though. A client and server replication connection will form a strong pair, as soon as one is closed the other is closed too. So, there's no caching of the server replication connections, like is done for regular connections. Reusing replication connections comes with a ton of gotchas. Postgres will throw errors in many cases when trying to do so. So simply not doing it seems like a good tradeoff for ease of implementation. Especially because replication connections are pretty much always very long lived. So re-using them gains pretty much no performance benefits. Fixes #382 --- .cirrus.yml | 6 +- .editorconfig | 10 ++ Makefile | 2 + doc/config.md | 2 +- include/bouncer.h | 16 ++ include/common/builtins.h | 17 ++ include/common/postgres_compat.h | 3 +- include/hba.h | 2 +- include/objects.h | 1 + include/pktbuf.h | 3 - include/server.h | 1 + include/util.h | 2 + include/varcache.h | 1 + src/client.c | 83 ++++++++- src/common/bool.c | 112 +++++++++++++ src/hba.c | 21 ++- src/janitor.c | 25 ++- src/loader.c | 14 -- src/objects.c | 131 ++++++++++++--- src/proto.c | 72 +++++++- src/server.c | 30 +++- src/util.c | 15 ++ src/varcache.c | 15 ++ test/conftest.py | 1 + test/hba_test.c | 10 +- test/hba_test.eval | 13 ++ test/hba_test.rules | 5 + test/ssl/test.ini | 1 + test/test.ini | 10 ++ test/test_limits.py | 5 +- test/test_replication.py | 280 +++++++++++++++++++++++++++++++ test/test_ssl.py | 38 +++++ test/utils.py | 211 +++++++++++++++++++++-- 33 files changed, 1057 insertions(+), 101 deletions(-) create mode 100644 include/common/builtins.h create mode 100644 src/common/bool.c create mode 100644 test/test_replication.py diff --git a/.cirrus.yml b/.cirrus.yml index 438db2c10f36..3d58a995af8d 100644 --- a/.cirrus.yml +++ b/.cirrus.yml @@ -103,7 +103,7 @@ task: # - image: rockylinux:8 # - image: centos:centos7 # setup_script: -# - yum -y install autoconf automake diffutils file libevent-devel libtool make openssl-devel pkg-config postgresql-server systemd-devel wget +# - yum -y install autoconf automake diffutils file libevent-devel libtool make openssl-devel pkg-config postgresql-server postgresql-contrib systemd-devel wget # - if cat /etc/centos-release | grep -q ' 7'; then yum -y install python python-pip; else yum -y install python3 python3-pip sudo iptables; fi # - wget -O /tmp/pandoc.tar.gz https://github.com/jgm/pandoc/releases/download/2.10.1/pandoc-2.10.1-linux-amd64.tar.gz # - tar xvzf /tmp/pandoc.tar.gz --strip-components 1 -C /usr/local/ @@ -133,7 +133,7 @@ task: # - image: alpine:latest # setup_script: # - apk update -# - apk add autoconf automake bash build-base libevent-dev libtool openssl openssl-dev pkgconf postgresql python3 py3-pip wget sudo iptables +# - apk add autoconf automake bash build-base libevent-dev libtool openssl openssl-dev pkgconf 
postgresql postgresql-contrib python3 py3-pip wget sudo iptables # - wget -O /tmp/pandoc.tar.gz https://github.com/jgm/pandoc/releases/download/2.10.1/pandoc-2.10.1-linux-amd64.tar.gz # - tar xvzf /tmp/pandoc.tar.gz --strip-components 1 -C /usr/local/ # - python3 -m pip install -r requirements.txt @@ -161,7 +161,7 @@ task: HAVE_IPV6_LOCALHOST: yes USE_SUDO: true setup_script: - - pkg install -y autoconf automake bash gmake hs-pandoc libevent libtool pkgconf postgresql12-server python devel/py-pip sudo + - pkg install -y autoconf automake bash gmake hs-pandoc libevent libtool pkgconf postgresql12-server postgresql12-contrib python devel/py-pip sudo - pip install -r requirements.txt - kldload pf - echo 'anchor "pgbouncer_test/*"' >> /etc/pf.conf diff --git a/.editorconfig b/.editorconfig index 4c8a713d613e..678f626abaf2 100644 --- a/.editorconfig +++ b/.editorconfig @@ -12,6 +12,16 @@ trim_trailing_whitespace = true indent_style = tab indent_size = 8 +[hba_test.{eval,rules}] +indent_style = tab +indent_size = 8 + +[hba_test.rules] +# Disable trailing_whitespace check for hba_test.rules, because one of the +# tests in that file is that parsing doesn't break in case of trailing +# whitespace. +trim_trailing_whitespace = false + [*.py] indent_style = space indent_size = 4 diff --git a/Makefile b/Makefile index 5158c2af3f3d..68f26d063c69 100644 --- a/Makefile +++ b/Makefile @@ -25,6 +25,7 @@ pgbouncer_SOURCES = \ src/util.c \ src/varcache.c \ src/common/base64.c \ + src/common/bool.c \ src/common/saslprep.c \ src/common/scram-common.c \ src/common/unicode_norm.c \ @@ -51,6 +52,7 @@ pgbouncer_SOURCES = \ include/util.h \ include/varcache.h \ include/common/base64.h \ + include/common/builtins.h \ include/common/pg_wchar.h \ include/common/postgres_compat.h \ include/common/saslprep.h \ diff --git a/doc/config.md b/doc/config.md index 7339246b3277..8aab1ce9efe8 100644 --- a/doc/config.md +++ b/doc/config.md @@ -1285,7 +1285,7 @@ The file follows the format of the PostgreSQL `pg_hba.conf` file (see ). * Supported record types: `local`, `host`, `hostssl`, `hostnossl`. -* Database field: Supports `all`, `sameuser`, `@file`, multiple names. Not supported: `replication`, `samerole`, `samegroup`. +* Database field: Supports `all`, `replication`, `sameuser`, `@file`, multiple names. Not supported: `samerole`, `samegroup`. * User name field: Supports `all`, `@file`, multiple names. Not supported: `+groupname`. * Address field: Supports IPv4, IPv6. Not supported: DNS names, domain prefixes. * Auth-method field: Only methods supported by PgBouncer's `auth_type` diff --git a/include/bouncer.h b/include/bouncer.h index eee300887967..1a88a02565b4 100644 --- a/include/bouncer.h +++ b/include/bouncer.h @@ -104,6 +104,7 @@ typedef union PgAddr PgAddr; typedef enum SocketState SocketState; typedef struct PktHdr PktHdr; typedef struct ScramState ScramState; +typedef enum ReplicationType ReplicationType; extern int cf_sbuf_len; @@ -305,6 +306,11 @@ struct PgPool { * Clients that sent cancel request, to cancel another client its query. * These requests are waiting for a new server connection to be opened, * before the request can be forwarded. + * + * This is a separate list from waiting_client_list, because we want to + * give cancel requests priority over regular clients. The main reason + * for this is, because a cancel request might free up a connection, + * which can be used for one of the waiting clients. 
*/ struct StatList waiting_cancel_req_list; @@ -504,6 +510,13 @@ struct PgDatabase { struct AATree user_tree; /* users that have been queried on this database */ }; +enum ReplicationType { + REPLICATION_NONE = 0, + REPLICATION_LOGICAL, + REPLICATION_PHYSICAL, +}; + +extern const char *replication_type_parameters[3]; /* * A client or server connection. @@ -543,6 +556,9 @@ struct PgSocket { bool wait_sslchar : 1; /* server: waiting for ssl response: S/N */ + ReplicationType replication; /* If this is a replication connection */ + char *startup_options; /* only tracked for replication connections */ + int expect_rfq_count; /* client: count of ReadyForQuery packets client should see */ usec_t connect_time; /* when connection was made */ diff --git a/include/common/builtins.h b/include/common/builtins.h new file mode 100644 index 000000000000..aede401a94db --- /dev/null +++ b/include/common/builtins.h @@ -0,0 +1,17 @@ +/*------------------------------------------------------------------------- + * + * builtins.h + * Declarations for operations on built-in types. + * + * + * Portions Copyright (c) 1996-2023, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + * + * include/common/builtins.h + * + *------------------------------------------------------------------------- + */ + +/* bool.c */ +extern bool parse_bool(const char *value, bool *result); +extern bool parse_bool_with_len(const char *value, size_t len, bool *result); diff --git a/include/common/postgres_compat.h b/include/common/postgres_compat.h index 71afe25192af..98267fba5095 100644 --- a/include/common/postgres_compat.h +++ b/include/common/postgres_compat.h @@ -7,6 +7,7 @@ /* from c.h */ #include +#include #define int8 int8_t #define uint8 uint8_t @@ -15,6 +16,7 @@ #define lengthof(array) (sizeof (array) / sizeof ((array)[0])) #define pg_hton32(x) htobe32(x) +#define pg_strncasecmp strncasecmp #define pg_attribute_noreturn() _NORETURN @@ -30,6 +32,5 @@ #define pg_sha256_update(ctx, data, len) sha256_update(ctx, data, len) #define pg_sha256_final(ctx, dst) sha256_final(ctx, dst) - /* define this to use non-server code paths */ #define FRONTEND diff --git a/include/hba.h b/include/hba.h index 7e8c3bb752f6..fc92b96bf7d6 100644 --- a/include/hba.h +++ b/include/hba.h @@ -20,4 +20,4 @@ struct HBA; struct HBA *hba_load_rules(const char *fn); void hba_free(struct HBA *hba); -int hba_eval(struct HBA *hba, PgAddr *addr, bool is_tls, const char *dbname, const char *username); +int hba_eval(struct HBA *hba, PgAddr *addr, bool is_tls, ReplicationType replication, const char *dbname, const char *username); diff --git a/include/objects.h b/include/objects.h index 1a7f7aa535cf..db1432aa8285 100644 --- a/include/objects.h +++ b/include/objects.h @@ -42,6 +42,7 @@ PgPool *get_pool(PgDatabase *, PgUser *); PgPool *get_peer_pool(PgDatabase *); PgSocket *compare_connections_by_time(PgSocket *lhs, PgSocket *rhs); bool evict_connection(PgDatabase *db) _MUSTCHECK; +bool evict_pool_connection(PgPool *pool) _MUSTCHECK; bool evict_user_connection(PgUser *user) _MUSTCHECK; bool find_server(PgSocket *client) _MUSTCHECK; bool life_over(PgSocket *server); diff --git a/include/pktbuf.h b/include/pktbuf.h index 7c51a0ff3e6c..35c8c26cd35f 100644 --- a/include/pktbuf.h +++ b/include/pktbuf.h @@ -97,9 +97,6 @@ void pktbuf_write_ExtQuery(PktBuf *buf, const char *query, int nargs, ...); #define pktbuf_write_CancelRequest(buf, key) \ pktbuf_write_generic(buf, PKT_CANCEL, "b", key, 8) -#define 
pktbuf_write_StartupMessage(buf, user, parms, parms_len) \ - pktbuf_write_generic(buf, PKT_STARTUP, "bsss", parms, parms_len, "user", user, "") - #define pktbuf_write_PasswordMessage(buf, psw) \ pktbuf_write_generic(buf, 'p', "s", psw) diff --git a/include/server.h b/include/server.h index bb6195a5fbaf..7ca791a8a44c 100644 --- a/include/server.h +++ b/include/server.h @@ -18,6 +18,7 @@ bool server_proto(SBuf *sbuf, SBufEvent evtype, struct MBuf *pkt) _MUSTCHECK; void kill_pool_logins(PgPool *pool, const char *sqlstate, const char *msg); +int server_pool_mode(PgSocket *server) _MUSTCHECK; int pool_pool_mode(PgPool *pool) _MUSTCHECK; int pool_pool_size(PgPool *pool) _MUSTCHECK; int pool_min_pool_size(PgPool *pool) _MUSTCHECK; diff --git a/include/util.h b/include/util.h index 1a979b010cc2..93d419200fa7 100644 --- a/include/util.h +++ b/include/util.h @@ -70,3 +70,5 @@ bool cf_set_authdb(struct CfValue *cv, const char *value); /* reserved database name checking */ bool check_reserved_database(const char *value); + +bool strings_equal(const char *str_left, const char *str_right) _MUSTCHECK; diff --git a/include/varcache.h b/include/varcache.h index 21599666a900..eb4caa2532ff 100644 --- a/include/varcache.h +++ b/include/varcache.h @@ -19,6 +19,7 @@ void init_var_lookup(const char *cf_track_extra_parameters); int get_num_var_cached(void); bool varcache_set(VarCache *cache, const char *key, const char *value) /* _MUSTCHECK */; bool varcache_apply(PgSocket *server, PgSocket *client, bool *changes_p) _MUSTCHECK; +void varcache_apply_startup(PktBuf *pkt, PgSocket *client); void varcache_fill_unset(VarCache *src, PgSocket *dst); void varcache_clean(VarCache *cache); void varcache_add_params(PktBuf *pkt, VarCache *vars); diff --git a/src/client.c b/src/client.c index b37483ca358e..1186d351e6ad 100644 --- a/src/client.c +++ b/src/client.c @@ -23,6 +23,7 @@ #include "bouncer.h" #include "pam.h" #include "scram.h" +#include "common/builtins.h" #include @@ -267,11 +268,15 @@ static bool finish_set_pool(PgSocket *client, bool takeover) if (client->sbuf.tls) { char infobuf[96] = ""; tls_get_connection_info(client->sbuf.tls, infobuf, sizeof infobuf); - slog_info(client, "login attempt: db=%s user=%s tls=%s", - client->db->name, client->login_user->name, infobuf); + slog_info(client, "login attempt: db=%s user=%s tls=%s replication=%s", + client->db->name, + client->login_user->name, + infobuf, + replication_type_parameters[client->replication]); } else { - slog_info(client, "login attempt: db=%s user=%s tls=no", - client->db->name, client->login_user->name); + slog_info(client, "login attempt: db=%s user=%s tls=no replication=%s", + client->db->name, client->login_user->name, + replication_type_parameters[client->replication]); } } @@ -288,8 +293,13 @@ static bool finish_set_pool(PgSocket *client, bool takeover) auth = cf_auth_type; if (auth == AUTH_HBA) { - auth = hba_eval(parsed_hba, &client->remote_addr, !!client->sbuf.tls, - client->db->name, client->login_user->name); + auth = hba_eval( + parsed_hba, + &client->remote_addr, + !!client->sbuf.tls, + client->replication, + client->db->name, + client->login_user->name); } if (auth == AUTH_MD5) { @@ -595,6 +605,25 @@ static bool set_startup_options(PgSocket *client, const char *options) char arg_buf[400]; struct MBuf arg; const char *position = options; + + if (client->replication) { + /* + * Since replication clients will be bound 1-to-1 to a server + * connection, we can support any configuration flags and + * fields in the options startup parameter. 
Because we can + * simply send the exact same value for the options parameter + * when opening the replication connection to the server. This + * allows us to also support GUCs that don't have the + * GUC_REPORT flag, specifically extra_float_digits which is a + * configuration that is set by CREATE SUBSCRIPTION in the + * options parameter. + */ + client->startup_options = strdup(options); + if (!client->startup_options) + disconnect_client(client, true, "out of memory"); + return true; + } + mbuf_init_fixed_writer(&arg, arg_buf, sizeof(arg_buf)); slog_debug(client, "received options: %s", options); @@ -658,12 +687,52 @@ static void set_appname(PgSocket *client, const char *app_name) } } +/* + * set_replication sets the replication field on the client according the given + * replicationString. + */ +static bool set_replication(PgSocket *client, const char *replicationString) +{ + bool replicationBool = false; + if (strcmp(replicationString, "database") == 0) { + client->replication = REPLICATION_LOGICAL; + return true; + } + if (!parse_bool(replicationString, &replicationBool)) { + return false; + } + client->replication = replicationBool ? REPLICATION_PHYSICAL : REPLICATION_NONE; + return true; +} + static bool decide_startup_pool(PgSocket *client, PktHdr *pkt) { const char *username = NULL, *dbname = NULL; const char *key, *val; bool ok; bool appname_found = false; + unsigned original_read_pos = pkt->data.read_pos; + + /* + * First check if we're dealing with a replication connection. Because for + * those we support some additional things when parsing the startup + * parameters, specifically we support any arguments in the options startup + * packet. + */ + while (1) { + ok = mbuf_get_string(&pkt->data, &key); + if (!ok || *key == 0) + break; + ok = mbuf_get_string(&pkt->data, &val); + if (!ok) + break; + if (strcmp(key, "replication") == 0) { + slog_debug(client, "got var: %s=%s", key, val); + set_replication(client, val); + } + } + + pkt->data.read_pos = original_read_pos; while (1) { ok = mbuf_get_string(&pkt->data, &key); @@ -685,6 +754,8 @@ static bool decide_startup_pool(PgSocket *client, PktHdr *pkt) } else if (strcmp(key, "application_name") == 0) { set_appname(client, val); appname_found = true; + } else if (strcmp(key, "replication") == 0) { + /* do nothing, already checked in the previous loop */ } else if (varcache_set(&client->vars, key, val)) { slog_debug(client, "got var: %s=%s", key, val); } else if (strlist_contains(cf_ignore_startup_params, key)) { diff --git a/src/common/bool.c b/src/common/bool.c new file mode 100644 index 000000000000..b0543e1c58d2 --- /dev/null +++ b/src/common/bool.c @@ -0,0 +1,112 @@ +/*------------------------------------------------------------------------- + * + * bool.c + * Functions for the built-in type "bool". + * + * Portions Copyright (c) 1996-2023, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + * + * + * IDENTIFICATION + * src/backend/utils/adt/bool.c + * + *------------------------------------------------------------------------- + */ + +#include "common/postgres_compat.h" + +#include "common/builtins.h" + +/* + * Try to interpret value as boolean value. Valid values are: true, + * false, yes, no, on, off, 1, 0; as well as unique prefixes thereof. + * If the string parses okay, return true, else false. + * If okay and result is not NULL, return the value in *result. 
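For reference, the client-visible effect of set_replication() above: anything other than the literal string "database" is fed through parse_bool(), so the usual PostgreSQL boolean spellings (and unique prefixes of them) select physical replication or no replication. A rough Python rendering of that mapping, purely illustrative and not part of the patch:

```python
def parse_bool(value: str):
    """Rough Python rendering of parse_bool(): accepts true/false, yes/no,
    on/off, 1/0 and unique prefixes thereof (case-insensitive). Returns
    None for anything it cannot interpret."""
    v = value.lower()
    if v and "true".startswith(v):
        return True
    if v and "false".startswith(v):
        return False
    if v and "yes".startswith(v):
        return True
    if v and "no".startswith(v):
        return False
    if len(v) >= 2 and "on".startswith(v):   # a lone "o" is ambiguous
        return True
    if len(v) >= 2 and "off".startswith(v):
        return False
    if v in ("1", "0"):
        return v == "1"
    return None


def classify_replication(value: str) -> str:
    """How the "replication" startup parameter is interpreted by
    set_replication() in src/client.c (illustrative only)."""
    if value == "database":
        return "REPLICATION_LOGICAL"   # logical walsender; SQL allowed on PG10+
    parsed = parse_bool(value)
    if parsed is None:
        return "REPLICATION_NONE"      # unparsable values leave the default
    return "REPLICATION_PHYSICAL" if parsed else "REPLICATION_NONE"


assert classify_replication("database") == "REPLICATION_LOGICAL"
assert classify_replication("yes") == "REPLICATION_PHYSICAL"
assert classify_replication("off") == "REPLICATION_NONE"
```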
+ */ +bool +parse_bool(const char *value, bool *result) +{ + return parse_bool_with_len(value, strlen(value), result); +} + +bool +parse_bool_with_len(const char *value, size_t len, bool *result) +{ + switch (*value) + { + case 't': + case 'T': + if (pg_strncasecmp(value, "true", len) == 0) + { + if (result) + *result = true; + return true; + } + break; + case 'f': + case 'F': + if (pg_strncasecmp(value, "false", len) == 0) + { + if (result) + *result = false; + return true; + } + break; + case 'y': + case 'Y': + if (pg_strncasecmp(value, "yes", len) == 0) + { + if (result) + *result = true; + return true; + } + break; + case 'n': + case 'N': + if (pg_strncasecmp(value, "no", len) == 0) + { + if (result) + *result = false; + return true; + } + break; + case 'o': + case 'O': + /* 'o' is not unique enough */ + if (pg_strncasecmp(value, "on", (len > 2 ? len : 2)) == 0) + { + if (result) + *result = true; + return true; + } + else if (pg_strncasecmp(value, "off", (len > 2 ? len : 2)) == 0) + { + if (result) + *result = false; + return true; + } + break; + case '1': + if (len == 1) + { + if (result) + *result = true; + return true; + } + break; + case '0': + if (len == 1) + { + if (result) + *result = false; + return true; + } + break; + default: + break; + } + + if (result) + *result = false; /* suppress compiler warning */ + return false; +} diff --git a/src/hba.c b/src/hba.c index 99047b41c979..5af965259483 100644 --- a/src/hba.c +++ b/src/hba.c @@ -31,8 +31,9 @@ enum RuleType { RULE_HOSTNOSSL, }; -#define NAME_ALL 1 -#define NAME_SAMEUSER 2 +#define NAME_ALL 1 +#define NAME_SAMEUSER 2 +#define NAME_REPLICATION 4 struct NameSlot { size_t strlen; @@ -399,8 +400,8 @@ static bool parse_names(struct HBAName *hname, struct TokParser *tp, bool is_db, return false; } if (eat_kw(tp, "replication")) { - log_warning("replication is not supported"); - return false; + hname->flags |= NAME_REPLICATION; + goto eat_comma; } } @@ -699,7 +700,7 @@ static bool match_inet6(const struct HBARule *rule, PgAddr *addr) (src[2] & mask[2]) == base[2] && (src[3] & mask[3]) == base[3]; } -int hba_eval(struct HBA *hba, PgAddr *addr, bool is_tls, const char *dbname, const char *username) +int hba_eval(struct HBA *hba, PgAddr *addr, bool is_tls, ReplicationType replication, const char *dbname, const char *username) { struct List *el; struct HBARule *rule; @@ -733,8 +734,14 @@ int hba_eval(struct HBA *hba, PgAddr *addr, bool is_tls, const char *dbname, con } /* match db & user */ - if (!name_match(&rule->db_name, dbname, dbnamelen, username)) - continue; + if (replication == REPLICATION_PHYSICAL) { + if (!(rule->db_name.flags & NAME_REPLICATION)) { + continue; + } + } else { + if (!name_match(&rule->db_name, dbname, dbnamelen, username)) + continue; + } if (!name_match(&rule->user_name, username, unamelen, dbname)) continue; diff --git a/src/janitor.c b/src/janitor.c index 2616a3c3bb68..db7c40e689c7 100644 --- a/src/janitor.c +++ b/src/janitor.c @@ -187,7 +187,14 @@ static void per_loop_activate(PgPool *pool) sv_used = statlist_count(&pool->used_server_list); statlist_for_each_safe(item, &pool->waiting_client_list, tmp) { client = container_of(item, PgSocket, head); - if (!statlist_empty(&pool->idle_server_list)) { + if (client->replication) { + /* + * For replication connections we always launch + * a new connection, but we continue with the loop, + * because there might be normal clients waiting too. 
+ */ + launch_new_connection(pool, /* evict_if_needed= */ true); + } else if (!statlist_empty(&pool->idle_server_list)) { /* db not fully initialized after reboot */ if (client->wait_for_welcome && !pool->welcome_msg_ready) { launch_new_connection(pool, /* evict_if_needed= */ true); @@ -243,7 +250,7 @@ static int per_loop_suspend(PgPool *pool, bool force_suspend) active += suspend_socket_list(&pool->active_client_list, force_suspend); - /* this list is not suspendable, but still need force_suspend and counting */ + /* these lists are not suspendable, but still need force_suspend and counting */ active += suspend_socket_list(&pool->waiting_client_list, force_suspend); if (active) per_loop_activate(pool); @@ -534,13 +541,13 @@ static void pool_server_maint(PgPool *pool) check_unused_servers(pool, &pool->idle_server_list, 1); /* disconnect close_needed active servers if server_fast_close is set */ - if (cf_server_fast_close) { - statlist_for_each_safe(item, &pool->active_server_list, tmp) { - server = container_of(item, PgSocket, head); - Assert(server->state == SV_ACTIVE); - if (server->ready && server->close_needed) - disconnect_server(server, true, "database configuration changed"); - } + statlist_for_each_safe(item, &pool->active_server_list, tmp) { + server = container_of(item, PgSocket, head); + Assert(server->state == SV_ACTIVE); + if (cf_server_fast_close && server->ready && server->close_needed) + disconnect_server(server, true, "database configuration changed"); + if (server->replication != REPLICATION_NONE && server->close_needed) + disconnect_server(server, true, "database configuration changed"); } /* handle query_timeout and idle_transaction_timeout */ diff --git a/src/loader.c b/src/loader.c index c76484568c64..56a9717b0160 100644 --- a/src/loader.c +++ b/src/loader.c @@ -120,20 +120,6 @@ static char * cstr_get_pair(char *p, return cstr_skip_ws(p); } -/* - * Same as strcmp, but handles NULLs. If both sides are NULL, returns "true". 
- */ -static bool strings_equal(const char *str_left, const char *str_right) -{ - if (str_left == NULL && str_right == NULL) - return true; - - if (str_left == NULL || str_right == NULL) - return false; - - return strcmp(str_left, str_right) == 0; -} - static bool set_auth_dbname(PgDatabase *db, const char *new_auth_dbname) { if (strings_equal(db->auth_dbname, new_auth_dbname)) diff --git a/src/objects.c b/src/objects.c index ce6c5b3cb80b..87b7435ef340 100644 --- a/src/objects.c +++ b/src/objects.c @@ -73,6 +73,12 @@ static STATLIST(justfree_server_list); /* init autodb idle list */ STATLIST(autodatabase_idle_list); +const char *replication_type_parameters[] = { + [REPLICATION_NONE] = "no", + [REPLICATION_LOGICAL] = "database", + [REPLICATION_PHYSICAL] = "yes", +}; + /* fast way to get number of active clients */ int get_active_client_count(void) { @@ -695,10 +701,12 @@ PgPool *get_peer_pool(PgDatabase *db) /* deactivate socket and put into wait queue */ static void pause_client(PgSocket *client) { + SocketState newstate; Assert(client->state == CL_ACTIVE || client->state == CL_LOGIN); - slog_debug(client, "pause_client"); - change_client_state(client, CL_WAITING); + newstate = CL_WAITING; + + change_client_state(client, newstate); if (!sbuf_pause(&client->sbuf)) disconnect_client(client, true, "pause failed"); } @@ -802,9 +810,22 @@ bool find_server(PgSocket *client) if (client->link) return true; + slog_debug(client, "find_server: no linked server yet"); + /* try to get idle server, if allowed */ if (cf_pause_mode == P_PAUSE || pool->db->db_paused) { server = NULL; + } else if (client->replication) { + /* + * For replication clients we open dedicated server connections. These + * connections are linked to a client as soon as the server is ready, + * instead of lazily being assigned to a client only when the client + * sends a query. So if we reach this point we know that that has not + * happened yet, and we need to create a new replication connection for + * this client. 
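The branch above, together with the commit message's "no caching" design, is visible from the client side: every replication client gets a freshly opened server connection, and that server connection is torn down again the moment the client goes away. A minimal psycopg sketch, assuming PgBouncer listens on localhost:6432 and exposes a database entry called mydb (both placeholders, not part of the patch); on PostgreSQL 10 and newer a logical replication connection accepts commands like IDENTIFY_SYSTEM sent over the simple query protocol:

```python
import psycopg

# Placeholder conninfo: host/port/dbname/user are assumptions.
CONNINFO = "host=localhost port=6432 dbname=mydb user=postgres replication=database"

# First replication client: PgBouncer opens a dedicated server connection for it.
with psycopg.connect(CONNINFO, autocommit=True) as conn:
    print(conn.execute("IDENTIFY_SYSTEM").fetchone())
# Closing the client also closes its paired server connection
# ("replication client was closed" / "replication client disconnected").

# A second replication client therefore gets a brand-new server connection
# instead of reusing anything from the idle pool.
with psycopg.connect(CONNINFO, autocommit=True) as conn:
    print(conn.execute("IDENTIFY_SYSTEM").fetchone())
```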
+ */ + launch_new_connection(pool, /*evict_if_needed= */ true); + server = NULL; } else { while (1) { server = first_socket(&pool->idle_server_list); @@ -859,8 +880,11 @@ static bool reuse_on_release(PgSocket *server) { bool res = true; PgPool *pool = server->pool; - PgSocket *client = first_socket(&pool->waiting_client_list); - if (client) { + PgSocket *client; + Assert(!server->replication); + slog_debug(server, "reuse_on_release: replication %d", server->replication); + client = first_socket(&pool->waiting_client_list); + if (client && !client->replication) { activate_client(client); /* @@ -932,7 +956,7 @@ bool release_server(PgSocket *server) } if (*cf_server_reset_query && (cf_server_reset_query_always || - pool_pool_mode(pool) == POOL_SESSION)) { + server_pool_mode(server) == POOL_SESSION)) { /* notify reset is required */ newstate = SV_TESTED; } else if (cf_server_check_delay == 0 && *cf_server_check_query) { @@ -986,6 +1010,18 @@ bool release_server(PgSocket *server) return true; } + if (server->replication) { + if (server->link) { + slog_debug(server, "release_server: new replication connection ready"); + change_server_state(server, SV_ACTIVE); + activate_client(server->link); + return true; + } else { + disconnect_server(server, true, "replication client was closed"); + return false; + } + } + Assert(server->link == NULL); slog_noise(server, "release_server: new state=%d", newstate); change_server_state(server, newstate); @@ -1000,6 +1036,28 @@ bool release_server(PgSocket *server) return true; } +static void unlink_server(PgSocket *server, const char *reason) +{ + PgSocket *client; + if (!server->link) + return; + + client = server->link; + + client->link = NULL; + server->link = NULL; + /* + * Send reason to client if it is already + * logged in, otherwise send generic message. + */ + if (client->state == CL_ACTIVE || client->state == CL_WAITING) + disconnect_client(client, true, "%s", reason); + else if (client->state == CL_ACTIVE_CANCEL) + disconnect_client(client, false, "successfully sent cancel request"); + else + disconnect_client(client, true, "bouncer config error"); +} + /* * close server connection * @@ -1031,26 +1089,9 @@ void disconnect_server(PgSocket *server, bool send_term, const char *reason, ... switch (server->state) { case SV_ACTIVE_CANCEL: - case SV_ACTIVE: { - PgSocket *client = server->link; - - if (client) { - client->link = NULL; - server->link = NULL; - /* - * Send reason to client if it is already - * logged in, otherwise send generic message. - */ - if (client->state == CL_ACTIVE || client->state == CL_WAITING) - disconnect_client(client, true, "%s", reason); - else if (client->state == CL_ACTIVE_CANCEL) - disconnect_client(client, false, "successfully sent cancel request"); - else - disconnect_client(client, true, "bouncer config error"); - } - + case SV_ACTIVE: + unlink_server(server, reason); break; - } case SV_TESTED: case SV_USED: case SV_IDLE: @@ -1074,6 +1115,8 @@ void disconnect_server(PgSocket *server, bool send_term, const char *reason, ... server->pool->last_connect_failed = false; send_term = false; } + if (server->replication) + unlink_server(server, reason); break; default: fatal("bad server state: %d", server->state); @@ -1225,6 +1268,16 @@ void disconnect_client_sqlstate(PgSocket *client, bool notify, const char *sqlst break; case CL_WAITING: case CL_WAITING_LOGIN: + /* + * replication connections might already be linked to a server + * while they are still in a waiting state. 
+ */ + if (client->replication && client->link) { + PgSocket *server = client->link; + server->link = NULL; + client->link = NULL; + disconnect_server(server, false, "replication client disconnected"); + } break; default: fatal("bad client state: %d", client->state); @@ -1249,6 +1302,9 @@ void disconnect_client_sqlstate(PgSocket *client, bool notify, const char *sqlst client->db = NULL; } + free(client->startup_options); + client->startup_options = NULL; + change_client_state(client, CL_JUSTFREE); if (!sbuf_close(&client->sbuf)) log_noise("sbuf_close failed, retry later"); @@ -1443,6 +1499,21 @@ bool evict_connection(PgDatabase *db) return false; } +/* evict the oldest idle connection from the pool */ +bool evict_pool_connection(PgPool *pool) +{ + PgSocket *oldest_connection = NULL; + + oldest_connection = compare_connections_by_time(oldest_connection, last_socket(&pool->idle_server_list)); + + if (oldest_connection) { + disconnect_server(oldest_connection, true, "evicted"); + return true; + } + return false; +} + + /* evict the single most idle connection from among all pools to make room in the user */ bool evict_user_connection(PgUser *user) { @@ -1483,6 +1554,7 @@ void launch_new_connection(PgPool *pool, bool evict_if_needed) PgSocket *server; int max; + log_debug("launch_new_connection: start"); /* * Allow only a single connection attempt at a time. * @@ -1538,9 +1610,9 @@ void launch_new_connection(PgPool *pool, bool evict_if_needed) /* is it allowed to add servers? */ if (max >= pool_pool_size(pool) && pool->welcome_msg_ready) { /* should we use reserve pool? */ + PgSocket *c = first_socket(&pool->waiting_client_list); if (cf_res_pool_timeout && pool_res_pool_size(pool)) { usec_t now = get_cached_time(); - PgSocket *c = first_socket(&pool->waiting_client_list); if (c && (now - c->request_time) >= cf_res_pool_timeout) { if (max < pool_pool_size(pool) + pool_res_pool_size(pool)) { slog_warning(c, "taking connection from reserve_pool"); @@ -1548,6 +1620,15 @@ void launch_new_connection(PgPool *pool, bool evict_if_needed) } } } + + if (c && c->replication) { + while (evict_if_needed && pool_pool_size(pool) >= max) { + if (!evict_pool_connection(pool)) + break; + } + if (pool_pool_size(pool) < max) + goto allow_new; + } log_debug("launch_new_connection: pool full (%d >= %d)", max, pool_pool_size(pool)); return; diff --git a/src/proto.c b/src/proto.c index fccde842c332..2d5a182d3984 100644 --- a/src/proto.c +++ b/src/proto.c @@ -593,15 +593,69 @@ bool answer_authreq(PgSocket *server, PktHdr *pkt) bool send_startup_packet(PgSocket *server) { - PgDatabase *db = server->pool->db; - const char *username = server->pool->user->name; - PktBuf *pkt; - - pkt = pktbuf_temp(); - pktbuf_write_StartupMessage(pkt, username, - db->startup_params->buf, - db->startup_params->write_pos); - return pktbuf_send_immediate(pkt, server); + PgPool *pool = server->pool; + PgDatabase *db = pool->db; + const char *username = pool->user->name; + PktBuf *pkt = pktbuf_temp(); + PgSocket *client = NULL; + + pktbuf_start_packet(pkt, PKT_STARTUP); + pktbuf_put_bytes(pkt, db->startup_params->buf, db->startup_params->write_pos); + + /* + * If the next client in the list is a replication connection, we need + * to do some special stuff for it. 
+ */ + client = first_socket(&pool->waiting_client_list); + if (client && client->replication) { + server->replication = client->replication; + pktbuf_put_string(pkt, "replication"); + slog_debug(server, "send_startup_packet: creating replication connection"); + pktbuf_put_string(pkt, replication_type_parameters[server->replication]); + + /* + * For a replication connection we apply the varcache in the + * startup instead of through SET commands after connecting. + * The main reason to do so is because physical replication + * connections don't allow SET commands. A second reason is + * because it allows us to skip running the SET logic + * completely, which normally requires waiting on multiple + * server responses. This SET logic is normally executed in the + * codepath where we link the client to the server + * (find_server), but because we link the client here already + * we don't run that code for replication connections. Adding + * the varcache parameters to the startup message allows us to + * skip the dance that involves sending Query packets and + * waiting for responses. + */ + varcache_apply_startup(pkt, client); + if (client->startup_options) { + pktbuf_put_string(pkt, "options"); + pktbuf_put_string(pkt, client->startup_options); + } + } + + pktbuf_put_string(pkt, "user"); + pktbuf_put_string(pkt, username); + pktbuf_put_string(pkt, ""); /* terminator required in StartupMessage */ + pktbuf_finish_packet(pkt); + + if (!pktbuf_send_immediate(pkt, server)) { + return false; + } + + if (server->replication) { + /* + * We link replication connections to a client directly when they are + * created. One reason for is because the startup parameters need to be + * forwarded, because physical replication connections don't allow SET + * commands. Another reason is so that we don't need a separate state. + */ + client->link = server; + server->link = client; + } + + return true; } bool send_sslreq_packet(PgSocket *server) diff --git a/src/server.c b/src/server.c index 398b4fe11589..cd5fab5a1017 100644 --- a/src/server.c +++ b/src/server.c @@ -122,7 +122,18 @@ static bool handle_server_startup(PgSocket *server, PktHdr *pkt) break; case 'E': /* ErrorResponse */ - if (!server->pool->welcome_msg_ready) + /* + * If we haven't been able to connect to the server since the + * startup (or call to tag_pool_dirty), then we drop all + * clients that are currently trying to log in because they + * will almost certainly hit the same error. + * + * However, we don't do this if it's a replication connection, + * because those can fail due to a variety of missing + * permissions, while normal connections would still be able + * connect and query the database just fine. + */ + if (!server->pool->welcome_msg_ready && !server->replication) kill_pool_logins_server_error(server->pool, pkt); else log_server_error("S: login failed", pkt); @@ -192,6 +203,18 @@ static bool handle_server_startup(PgSocket *server, PktHdr *pkt) return res; } +/* + * server_pool_mode returns the pool_mode for the server. It specifically + * forces session pooling if the server is a replication connection, because + * replication connections require session pooling to work correctly. 
+ */ +int server_pool_mode(PgSocket *server) +{ + if (server->replication) + return POOL_SESSION; + return pool_pool_mode(server->pool); +} + int pool_pool_mode(PgPool *pool) { int pool_mode = pool->user->pool_mode; @@ -270,7 +293,7 @@ static bool handle_server_work(PgSocket *server, PktHdr *pkt) /* set ready only if no tx */ if (state == 'I') { ready = true; - } else if (pool_pool_mode(server->pool) == POOL_STMT) { + } else if (server_pool_mode(server) == POOL_STMT) { disconnect_server(server, true, "transaction blocks not allowed in statement pooling mode"); return false; } else if (state == 'T' || state == 'E') { @@ -343,6 +366,7 @@ static bool handle_server_work(PgSocket *server, PktHdr *pkt) /* copy mode */ case 'G': /* CopyInResponse */ case 'H': /* CopyOutResponse */ + case 'W': /* CopyBothResponse */ server->copy_mode = true; break; /* chat packets */ @@ -619,7 +643,7 @@ bool server_proto(SBuf *sbuf, SBufEvent evtype, struct MBuf *data) break; } - if (pool_pool_mode(pool) != POOL_SESSION || server->state == SV_TESTED || server->resetting) { + if (server_pool_mode(server) != POOL_SESSION || server->state == SV_TESTED || server->resetting) { server->resetting = false; switch (server->state) { case SV_ACTIVE: diff --git a/src/util.c b/src/util.c index b87b226fe93a..1615b323ac6b 100644 --- a/src/util.c +++ b/src/util.c @@ -496,3 +496,18 @@ bool check_reserved_database(const char *value) } return true; } + +/* + * Same as strcmp, but handles NULLs. If both sides are NULL, returns "true". + */ +bool strings_equal(const char *str_left, const char *str_right) +{ + if (str_left == NULL && str_right == NULL) + return true; + + if (str_left == NULL || str_right == NULL) + return false; + + return strcmp(str_left, str_right) == 0; +} + diff --git a/src/varcache.c b/src/varcache.c index 37de1709aada..1f671ec8ca4c 100644 --- a/src/varcache.c +++ b/src/varcache.c @@ -233,6 +233,21 @@ bool varcache_apply(PgSocket *server, PgSocket *client, bool *changes_p) return pktbuf_send_immediate(pkt, server); } +void varcache_apply_startup(PktBuf *pkt, PgSocket *client) +{ + const struct var_lookup *lk, *tmp; + + HASH_ITER(hh, lookup_map, lk, tmp) { + struct PStr *val = get_value(&client->vars, lk); + if (!val) + continue; + + slog_debug(client, "varcache_apply_startup: %s=%s", lk->name, val->str); + pktbuf_put_string(pkt, lk->name); + pktbuf_put_string(pkt, val->str); + } +} + void varcache_fill_unset(VarCache *src, PgSocket *dst) { struct PStr *srcval, *dstval; diff --git a/test/conftest.py b/test/conftest.py index cf55ef91e8b1..bf050b5595e5 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -63,6 +63,7 @@ def pg(tmp_path_factory, cert_dir): f.write(f"ssl_cert_file='{cert}'\n") f.write(f"ssl_key_file='{key}'\n") + pg.nossl_access("replication", "trust", user="postgres") pg.nossl_access("all", "trust") pg.nossl_access("p4", "password") pg.nossl_access("p5", "md5") diff --git a/test/hba_test.c b/test/hba_test.c index 839b3e48c0db..90b0dcc436e7 100644 --- a/test/hba_test.c +++ b/test/hba_test.c @@ -65,9 +65,11 @@ static char *get_token(char **ln_p) static int hba_test_eval(struct HBA *hba, char *ln, int linenr) { - const char *addr=NULL, *user=NULL, *db=NULL, *tls=NULL, *exp=NULL; + const char *addr=NULL, *user=NULL, *db=NULL, *modifier=NULL, *exp=NULL; PgAddr pgaddr; int res; + bool tls; + ReplicationType replication; if (ln[0] == '#') return 0; @@ -75,7 +77,9 @@ static int hba_test_eval(struct HBA *hba, char *ln, int linenr) db = get_token(&ln); user = get_token(&ln); addr = get_token(&ln); - tls = 
get_token(&ln); + modifier = get_token(&ln); + tls = strings_equal(modifier, "tls"); + replication = strings_equal(modifier, "replication") ? REPLICATION_PHYSICAL : REPLICATION_NONE; if (!exp) return 0; if (!db || !user) @@ -84,7 +88,7 @@ static int hba_test_eval(struct HBA *hba, char *ln, int linenr) if (!pga_pton(&pgaddr, addr, 9999)) die("hbatest: invalid addr on line #%d", linenr); - res = hba_eval(hba, &pgaddr, !!tls, db, user); + res = hba_eval(hba, &pgaddr, tls, replication, db, user); if (strcmp(method2string(res), exp) == 0) { res = 0; } else { diff --git a/test/hba_test.eval b/test/hba_test.eval index 3cd28b98f234..f25f7b6cce2c 100644 --- a/test/hba_test.eval +++ b/test/hba_test.eval @@ -79,3 +79,16 @@ md5 mdb muser ff11:2::1 md5 mdb muser ff22:3::1 trust mdb muser ::1 reject mdb muser ::2 + +# replication +reject mdb muser ::1 replication +reject db userp unix replication +reject replication userp unix replication +trust db admin ::1 replication +trust replication admin ::1 replication +reject db admin ::1 +reject replication admin ::1 +trust db admin2 ::1 replication +trust replication admin2 ::1 replication +trust db2 admin2 ::1 +reject replication admin2 ::1 diff --git a/test/hba_test.rules b/test/hba_test.rules index 1efe467e8b75..03d6cfd45a5a 100644 --- a/test/hba_test.rules +++ b/test/hba_test.rules @@ -5,6 +5,7 @@ # hostssl DATABASE USER ADDRESS METHOD [OPTIONS] # hostnossl DATABASE USER ADDRESS METHOD [OPTIONS] +# The following lines test that weird whitespace does not break parsing # ws # z @@ -48,3 +49,7 @@ host mdb2 muser 128.0.0.0/1 cert host mdb muser ff11::0/16 md5 host mdb muser ff20::/12 md5 host mdb muser ::1/128 trust + +# replication +host replication admin ::1/128 trust +host db2,replication admin2 ::1/128 trust diff --git a/test/ssl/test.ini b/test/ssl/test.ini index 7d76e30c94b5..1e7158619521 100644 --- a/test/ssl/test.ini +++ b/test/ssl/test.ini @@ -1,6 +1,7 @@ [databases] p0 = port=6666 host=localhost dbname=p0 user=bouncer pool_size=2 p1 = port=6666 host=localhost dbname=p1 user=bouncer +p7a= port=6666 host=localhost dbname=p7 [pgbouncer] logfile = test.log diff --git a/test/test.ini b/test/test.ini index 6685f8e63791..31090ac8c459 100644 --- a/test/test.ini +++ b/test/test.ini @@ -25,12 +25,22 @@ p7b= port=6666 host=127.0.0.1 dbname=p7 p7c= port=6666 host=127.0.0.1 dbname=p7 p8 = port=6666 host=127.0.0.1 dbname=p0 connect_query='set enable_seqscan=off; set enable_nestloop=off' +user_passthrough = port=6666 host=127.0.0.1 dbname=p0 +user_passthrough2 = port=6666 host=127.0.0.1 dbname=p2 +user_passthrough_pool_size2 = port=6666 host=127.0.0.1 dbname=p0 pool_size=2 +user_passthrough_pool_size5 = port=6666 host=127.0.0.1 dbname=p0 pool_size=2 + pauthz = port=6666 host=127.0.0.1 dbname=p7 auth_user=pswcheck auth_dbname=authdb authdb = port=6666 host=127.0.0.1 dbname=p1 auth_user=pswcheck hostlist1 = port=6666 host=127.0.0.1,::1 dbname=p0 user=bouncer hostlist2 = port=6666 host=127.0.0.1,127.0.0.1 dbname=p0 user=bouncer +; Needed for pg_receivewal and pg_basebackup tests until this patch is merged +; in PostgreSQL (and we don't support PG16 anymore): +; https://www.postgresql.org/message-id/flat/CAGECzQTw-dZkVT_RELRzfWRzY714-VaTjoBATYfZq93R8C-auA@mail.gmail.com +replication = port=6666 host=127.0.0.1 dbname=replication + ; commented out except for auto-database tests ;* = port=6666 host=127.0.0.1 diff --git a/test/test_limits.py b/test/test_limits.py index 1daf9afdfac3..401fa2261826 100644 --- a/test/test_limits.py +++ b/test/test_limits.py @@ -1,4 +1,5 @@ 
import asyncio +import time import psycopg import pytest @@ -63,7 +64,7 @@ def test_min_pool_size_with_lower_max_user_connections(bouncer): # Running a query for sufficient time for us to reach the final # connection count in the pool and detect any evictions. - with bouncer.log_contains("new connection to server", times=2): + with bouncer.log_contains(r"new connection to server \(from", times=2): with bouncer.log_contains("closing because: evicted", times=0): bouncer.sleep(2, dbname="p0x", user="maxedout2") @@ -76,7 +77,7 @@ def test_min_pool_size_with_lower_max_db_connections(bouncer): # Running a query for sufficient time for us to reach the final # connection count in the pool and detect any evictions. - with bouncer.log_contains("new connection to server", times=2): + with bouncer.log_contains(r"new connection to server \(from", times=2): with bouncer.log_contains("closing because: evicted", times=0): bouncer.sleep(2, dbname="p0y", user="puser1") diff --git a/test/test_replication.py b/test/test_replication.py new file mode 100644 index 000000000000..a8bbda8e8c6d --- /dev/null +++ b/test/test_replication.py @@ -0,0 +1,280 @@ +import asyncio +import signal +import subprocess +import time + +import psycopg +import psycopg.errors +import pytest +from psycopg import sql + +from .utils import PG_MAJOR_VERSION, WINDOWS, run + + +def test_logical_rep(bouncer): + connect_args = { + "dbname": "user_passthrough", + "replication": "database", + "user": "postgres", + "application_name": "abc", + "options": "-c enable_seqscan=off", + } + # Starting in PG10 you can do other commands over logical rep connections + if PG_MAJOR_VERSION >= 10: + bouncer.test(**connect_args) + assert bouncer.sql_value("SHOW application_name", **connect_args) == "abc" + assert bouncer.sql_value("SHOW enable_seqscan", **connect_args) == "off" + bouncer.sql("IDENTIFY_SYSTEM", **connect_args) + # Do a normal connection to the same pool, to ensure that that doesn't + # break anything + bouncer.test(dbname="user_passthrough", user="postgres") + bouncer.sql("IDENTIFY_SYSTEM", **connect_args) + + +def test_logical_rep_unprivileged(bouncer): + if PG_MAJOR_VERSION < 10: + expected_log = "no pg_hba.conf entry for replication connection" + elif PG_MAJOR_VERSION < 16: + expected_log = "must be superuser or replication role to start walsender" + else: + expected_log = "permission denied to start WAL sender" + + with bouncer.log_contains( + expected_log, + ), bouncer.log_contains( + r"closing because: login failed \(age", times=2 + ), pytest.raises(psycopg.OperationalError, match=r"login failed"): + bouncer.sql("IDENTIFY_SYSTEM", replication="database") + + +@pytest.mark.skipif( + "PG_MAJOR_VERSION < 10", reason="logical replication was introduced in PG10" +) +def test_logical_rep_subscriber(bouncer): + bouncer.admin("set pool_mode=transaction") + + # First write create a table and insert a row in the source database. 
+ # Also create the replication slot and publication + bouncer.default_db = "user_passthrough" + bouncer.create_schema("test_logical_rep_subscriber") + bouncer.sql("CREATE TABLE test_logical_rep_subscriber.table(a int)") + bouncer.sql("INSERT INTO test_logical_rep_subscriber.table values (1)") + assert ( + bouncer.sql_value("SELECT count(*) FROM test_logical_rep_subscriber.table") == 1 + ) + + bouncer.create_publication( + "mypub", sql.SQL("FOR TABLE test_logical_rep_subscriber.table") + ) + + bouncer.create_logical_replication_slot("test_logical_rep_subscriber", "pgoutput") + + # Create an equivalent, but empty schema in the target database. + # And setup the subscription + bouncer.default_db = "user_passthrough2" + bouncer.create_schema("test_logical_rep_subscriber") + bouncer.sql("CREATE TABLE test_logical_rep_subscriber.table(a int)") + conninfo = bouncer.make_conninfo(dbname="user_passthrough") + bouncer.create_subscription( + "mysub", + sql.SQL( + """ + CONNECTION {} + PUBLICATION mypub + WITH (slot_name=test_logical_rep_subscriber, create_slot=false) + """ + ).format(sql.Literal(conninfo)), + ) + + # The initial copy should now copy over the row + time.sleep(2) + assert ( + bouncer.sql_value("SELECT count(*) FROM test_logical_rep_subscriber.table") >= 1 + ) + + # Insert another row and logical replication should replicate it correctly + bouncer.sql( + "INSERT INTO test_logical_rep_subscriber.table values (2)", + dbname="user_passthrough", + ) + time.sleep(2) + assert ( + bouncer.sql_value("SELECT count(*) FROM test_logical_rep_subscriber.table") >= 2 + ) + + +@pytest.mark.skipif( + "WINDOWS", reason="MINGW does not have contrib package containing test_decoding" +) +def test_logical_rep_pg_recvlogical(bouncer): + bouncer.default_db = "user_passthrough" + bouncer.create_schema("test_logical_rep_pg_recvlogical") + bouncer.sql("CREATE TABLE test_logical_rep_pg_recvlogical.table(a int)") + bouncer.create_logical_replication_slot( + "test_logical_rep_pg_recvlogical", "test_decoding" + ) + process = subprocess.Popen( + [ + "pg_recvlogical", + "--dbname", + bouncer.default_db, + "--host", + bouncer.host, + "--port", + str(bouncer.port), + "--user", + bouncer.default_user, + "--slot=test_logical_rep_pg_recvlogical", + "--file=-", + "--no-loop", + "--start", + ], + stdout=subprocess.PIPE, + ) + assert process.stdout is not None + bouncer.sql("INSERT INTO test_logical_rep_pg_recvlogical.table values (1)") + try: + assert process.stdout.readline().startswith(b"BEGIN ") + assert ( + process.stdout.readline() + == b'table test_logical_rep_pg_recvlogical."table": INSERT: a[integer]:1\n' + ) + assert process.stdout.readline().startswith(b"COMMIT ") + finally: + process.kill() + process.communicate(timeout=5) + + +def test_physical_rep(bouncer): + connect_args = { + "dbname": "user_passthrough", + "replication": "yes", + "user": "postgres", + "application_name": "abc", + "options": "-c enable_seqscan=off", + } + # Starting in PG10 you can do SHOW commands + if PG_MAJOR_VERSION >= 10: + with pytest.raises( + psycopg.errors.FeatureNotSupported, + match="cannot execute SQL commands in WAL sender for physical replication", + ): + bouncer.test(**connect_args) + assert bouncer.sql_value("SHOW application_name", **connect_args) == "abc" + assert bouncer.sql_value("SHOW enable_seqscan", **connect_args) == "off" + bouncer.sql("IDENTIFY_SYSTEM", **connect_args) + # Do a normal connection to the same pool, to ensure that that doesn't + # break anything + bouncer.test(dbname="user_passthrough", user="postgres") 
+ bouncer.sql("IDENTIFY_SYSTEM", **connect_args) + + +def test_physcal_rep_unprivileged(bouncer): + with bouncer.log_contains( + r"no pg_hba.conf entry for replication connection from host" + ), bouncer.log_contains( + r"closing because: login failed \(age", times=2 + ), pytest.raises( + psycopg.OperationalError, match=r"login failed" + ): + bouncer.test(replication="yes") + + +@pytest.mark.skipif("PG_MAJOR_VERSION < 10", reason="pg_receivewal was added in PG10") +def test_physical_rep_pg_receivewal(bouncer, tmp_path): + bouncer.default_db = "user_passthrough" + bouncer.create_physical_replication_slot("test_physical_rep_pg_receivewal") + wal_dump_dir = tmp_path / "wal-dump" + wal_dump_dir.mkdir() + + process = subprocess.Popen( + [ + "pg_receivewal", + "--dbname", + bouncer.make_conninfo(), + "--slot=test_physical_rep_pg_receivewal", + "--directory", + str(wal_dump_dir), + ], + ) + time.sleep(3) + + if WINDOWS: + process.terminate() + else: + process.send_signal(signal.SIGINT) + process.communicate(timeout=5) + + if WINDOWS: + assert process.returncode == 1 + else: + assert process.returncode == 0 + + children = list(wal_dump_dir.iterdir()) + assert len(children) > 0 + + +def test_physical_rep_pg_basebackup(bouncer, tmp_path): + bouncer.default_db = "user_passthrough" + dump_dir = tmp_path / "db-dump" + dump_dir.mkdir() + + run( + [ + "pg_basebackup", + "--dbname", + bouncer.make_conninfo(), + "--checkpoint=fast", + "--pgdata", + str(dump_dir), + ], + shell=False, + ) + children = list(dump_dir.iterdir()) + assert len(children) > 0 + print(children) + + +@pytest.mark.asyncio +@pytest.mark.skipif( + "PG_MAJOR_VERSION < 10", + reason="normal SQL commands are only supported in PG10+ on logical replication connections", +) +async def test_replication_pool_size(pg, bouncer): + connect_args = { + "dbname": "user_passthrough_pool_size2", + "replication": "database", + "user": "postgres", + } + start = time.time() + await bouncer.asleep(0.5, times=10, **connect_args) + assert time.time() - start > 2.5 + # Replication connections always get closed right away + assert pg.connection_count("p0") == 0 + + connect_args["dbname"] = "user_passthrough_pool_size5" + start = time.time() + await bouncer.asleep(0.5, times=10, **connect_args) + assert time.time() - start > 1 + # Replication connections always get closed right away + assert pg.connection_count("p0") == 0 + + +@pytest.mark.asyncio +@pytest.mark.skipif( + "PG_MAJOR_VERSION < 10", + reason="normal SQL commands are only supported in PG10+ on logical replication connections", +) +async def test_replication_pool_size_mixed_clients(bouncer): + connect_args = { + "dbname": "user_passthrough_pool_size2", + "user": "postgres", + } + + # Fill the pool with normal connections + await bouncer.asleep(0.5, times=2, **connect_args) + + # Then try to open a replication connection and ensure that it causes + # eviction of one of the normal connections + with bouncer.log_contains("closing because: evicted"): + bouncer.test(**connect_args, replication="database") diff --git a/test/test_ssl.py b/test/test_ssl.py index 3a82c91e1a53..b7a3976ddb07 100644 --- a/test/test_ssl.py +++ b/test/test_ssl.py @@ -329,3 +329,41 @@ def test_client_ssl_scram(bouncer, cert_dir): sslmode="verify-full", sslrootcert=root, ) + + +def test_ssl_replication(pg, bouncer, cert_dir): + root = cert_dir / "TestCA1" / "ca.crt" + key = cert_dir / "TestCA1" / "sites" / "01-localhost.key" + cert = cert_dir / "TestCA1" / "sites" / "01-localhost.crt" + + bouncer.write_ini(f"server_tls_sslmode = 
verify-full") + bouncer.write_ini(f"server_tls_ca_file = {root}") + bouncer.write_ini(f"client_tls_sslmode = require") + bouncer.write_ini(f"client_tls_key_file = {key}") + bouncer.write_ini(f"client_tls_cert_file = {cert}") + bouncer.write_ini(f"client_tls_ca_file = {root}") + bouncer.admin("reload") + pg.ssl_access("all", "trust") + pg.ssl_access("replication", "trust", user="postgres") + pg.configure("ssl=on") + pg.configure(f"ssl_ca_file='{root}'") + + if PG_MAJOR_VERSION < 10 or WINDOWS: + pg.restart() + else: + pg.reload() + + # Logical rep + connect_args = { + "host": "localhost", + "dbname": "p7a", + "replication": "database", + "user": "postgres", + "application_name": "abc", + "sslmode": "verify-full", + "sslrootcert": root, + } + bouncer.psql("IDENTIFY_SYSTEM", **connect_args) + # physical rep + connect_args["replication"] = "true" + bouncer.psql("IDENTIFY_SYSTEM", **connect_args) diff --git a/test/utils.py b/test/utils.py index e02a131ee46e..709c50a0b69f 100644 --- a/test/utils.py +++ b/test/utils.py @@ -21,6 +21,8 @@ import filelock import psycopg +import psycopg.sql +from psycopg import sql TEST_DIR = Path(os.path.dirname(os.path.realpath(__file__))) os.chdir(TEST_DIR) @@ -158,6 +160,38 @@ def get_tls_support(): next_port = PORT_LOWER_BOUND +def notice_handler(diag: psycopg.errors.Diagnostic): + print(f"{diag.severity} ({diag.sqlstate}): {diag.message_primary}") + if diag.message_detail: + print(f"DETAIL: {diag.message_detail}") + if diag.message_hint: + print(f"HINT: {diag.message_hint}") + + +def cleanup_test_leftovers(*nodes): + """ + Cleaning up test leftovers needs to be done in a specific order, because + some of these leftovers depend on others having been removed. They might + even depend on leftovers on other nodes being removed. So this takes a list + of nodes, so that we can clean up all test leftovers globally in the + correct order. + """ + for node in nodes: + node.cleanup_subscriptions() + + for node in nodes: + node.cleanup_publications() + + for node in nodes: + node.cleanup_replication_slots() + + for node in nodes: + node.cleanup_schemas() + + for node in nodes: + node.cleanup_users() + + class PortLock: def __init__(self): global next_port @@ -191,9 +225,19 @@ def __init__(self, host, port): self.default_db = "postgres" self.default_user = "postgres" + # Used to track objects that we want to clean up at the end of a test + self.subscriptions = set() + self.publications = set() + self.replication_slots = set() + self.schemas = set() + self.users = set() + def set_default_connection_options(self, options): + """Sets the default connection options on the given options dictionary""" options.setdefault("dbname", self.default_db) options.setdefault("user", self.default_user) + options.setdefault("host", self.host) + options.setdefault("port", self.port) if ENABLE_VALGRIND: # If valgrind is enabled PgBouncer is a significantly slower to # respond to connection requests, so we wait a little longer. 
@@ -202,24 +246,27 @@ def set_default_connection_options(self, options): options.setdefault("connect_timeout", 3) # needed for Ubuntu 18.04 options.setdefault("client_encoding", "UTF8") + return options + + def make_conninfo(self, **kwargs) -> str: + self.set_default_connection_options(kwargs) + return psycopg.conninfo.make_conninfo(**kwargs) def conn(self, *, autocommit=True, **kwargs): """Open a psycopg connection to this server""" self.set_default_connection_options(kwargs) - return psycopg.connect( + conn = psycopg.connect( autocommit=autocommit, - host=self.host, - port=self.port, **kwargs, ) + conn.add_notice_handler(notice_handler) + return conn def aconn(self, *, autocommit=True, **kwargs): """Open an asynchronous psycopg connection to this server""" self.set_default_connection_options(kwargs) return psycopg.AsyncConnection.connect( autocommit=autocommit, - host=self.host, - port=self.port, **kwargs, ) @@ -448,6 +495,117 @@ def reject_traffic(self): elif BSD: sudo(f"pfctl -a pgbouncer_test/port_{self.port} -F all") + def create_user(self, name, args: typing.Optional[psycopg.sql.Composable] = None): + self.users.add(name) + if args is None: + args = sql.SQL("") + self.sql(sql.SQL("CREATE USER {} {}").format(sql.Identifier(name), args)) + + def create_schema(self, name, dbname=None): + dbname = dbname or self.default_db + self.schemas.add((dbname, name)) + self.sql(sql.SQL("CREATE SCHEMA {}").format(sql.Identifier(name))) + + def create_publication(self, name: str, args: psycopg.sql.Composable, dbname=None): + dbname = dbname or self.default_db + self.publications.add((dbname, name)) + self.sql(sql.SQL("CREATE PUBLICATION {} {}").format(sql.Identifier(name), args)) + + def create_logical_replication_slot(self, name, plugin): + self.replication_slots.add(name) + self.sql( + "SELECT pg_catalog.pg_create_logical_replication_slot(%s,%s)", + (name, plugin), + ) + + def create_physical_replication_slot(self, name): + self.replication_slots.add(name) + self.sql( + "SELECT pg_catalog.pg_create_physical_replication_slot(%s)", + (name,), + ) + + def create_subscription(self, name: str, args: psycopg.sql.Composable, dbname=None): + dbname = dbname or self.default_db + self.subscriptions.add((dbname, name)) + self.sql( + sql.SQL("CREATE SUBSCRIPTION {} {}").format(sql.Identifier(name), args) + ) + + def cleanup_users(self): + for user in self.users: + self.sql(sql.SQL("DROP USER IF EXISTS {}").format(sql.Identifier(user))) + + def cleanup_schemas(self): + for dbname, schema in self.schemas: + self.sql( + sql.SQL("DROP SCHEMA IF EXISTS {} CASCADE").format( + sql.Identifier(schema) + ), + dbname=dbname, + ) + + def cleanup_publications(self): + for dbname, publication in self.publications: + self.sql( + sql.SQL("DROP PUBLICATION IF EXISTS {}").format( + sql.Identifier(publication) + ), + dbname=dbname, + ) + + def cleanup_replication_slots(self): + for slot in self.replication_slots: + start = time.time() + while True: + try: + self.sql( + "SELECT pg_drop_replication_slot(slot_name) FROM pg_replication_slots WHERE slot_name = %s", + (slot,), + ) + except psycopg.errors.ObjectInUse: + if time.time() < start + 10: + time.sleep(0.5) + continue + raise + break + + def cleanup_subscriptions(self): + for dbname, subscription in self.subscriptions: + try: + self.sql( + sql.SQL("ALTER SUBSCRIPTION {} DISABLE").format( + sql.Identifier(subscription) + ), + dbname=dbname, + ) + except psycopg.errors.UndefinedObject: + # Subscription didn't exist already + continue + self.sql( + sql.SQL("ALTER SUBSCRIPTION 
{} SET (slot_name = NONE)").format( + sql.Identifier(subscription) + ), + dbname=dbname, + ) + self.sql( + sql.SQL("DROP SUBSCRIPTION {}").format(sql.Identifier(subscription)), + dbname=dbname, + ) + + def debug(self): + print("Connect manually to:\n ", repr(self.make_conninfo())) + print("Press Enter to continue running the test...") + input() + + def psql_debug(self, **kwargs): + conninfo = self.make_conninfo(**kwargs) + run( + ["psql", f"{conninfo}"], + shell=False, + silent=True, + ) + class Postgres(QueryRunner): def __init__(self, pgdata): @@ -471,6 +629,26 @@ def initdb(self): pgconf.write("log_connections = on\n") pgconf.write("log_disconnections = on\n") pgconf.write("logging_collector = off\n") + + # Allow CREATE SUBSCRIPTION to work + pgconf.write("wal_level = 'logical'\n") + # Faster logical replication status update so tests with logical replication + # run faster + pgconf.write("wal_receiver_status_interval = 1\n") + + # Faster logical replication apply worker launch so tests with logical + # replication run faster. This is used in ApplyLauncherMain in + # src/backend/replication/logical/launcher.c. + pgconf.write("wal_retrieve_retry_interval = '250ms'\n") + + # Make sure there's enough logical replication resources for our + # tests + if PG_MAJOR_VERSION >= 10: + pgconf.write("max_logical_replication_workers = 5\n") + pgconf.write("max_wal_senders = 5\n") + pgconf.write("max_replication_slots = 10\n") + pgconf.write("max_worker_processes = 20\n") + # We need to make the log go to stderr so that the tests can # check what is being logged. This should be the default, but # some packagings change the default configuration. @@ -524,24 +702,24 @@ async def arestart(self): process = await self.apgctl("-m fast restart") await process.communicate() - def nossl_access(self, dbname, auth_type): + def nossl_access(self, dbname, auth_type, user="all"): """Prepends a local non-SSL access to the HBA file""" with self.hba_path.open() as pghba: old_contents = pghba.read() with self.hba_path.open(mode="w") as pghba: if USE_UNIX_SOCKETS: - pghba.write(f"local {dbname} all {auth_type}\n") - pghba.write(f"hostnossl {dbname} all 127.0.0.1/32 {auth_type}\n") - pghba.write(f"hostnossl {dbname} all ::1/128 {auth_type}\n") + pghba.write(f"local {dbname} {user} {auth_type}\n") + pghba.write(f"hostnossl {dbname} {user} 127.0.0.1/32 {auth_type}\n") + pghba.write(f"hostnossl {dbname} {user} ::1/128 {auth_type}\n") pghba.write(old_contents) - def ssl_access(self, dbname, auth_type): + def ssl_access(self, dbname, auth_type, user="all"): """Prepends a local SSL access rule to the HBA file""" with self.hba_path.open() as pghba: old_contents = pghba.read() with self.hba_path.open(mode="w") as pghba: - pghba.write(f"hostssl {dbname} all 127.0.0.1/32 {auth_type}\n") - pghba.write(f"hostssl {dbname} all ::1/128 {auth_type}\n") + pghba.write(f"hostssl {dbname} {user} 127.0.0.1/32 {auth_type}\n") + pghba.write(f"hostssl {dbname} {user} ::1/128 {auth_type}\n") pghba.write(old_contents) @property @@ -663,6 +841,8 @@ def __init__( ini.write(f"logfile = {self.log_path}\n") ini.write(f"auth_file = {self.auth_path}\n") ini.write("pidfile = \n") + # Uncomment for much more noise but, more detailed debugging + # ini.write("verbose = 3\n") if not USE_UNIX_SOCKETS: ini.write(f"unix_socket_dir = \n") @@ -807,8 +987,11 @@ def print_logs(self): assert not failed_valgrind async def cleanup(self): - await self.stop() - self.print_logs() + try: + cleanup_test_leftovers(self) + await self.stop() + finally: + self.print_logs() if 
self.port_lock: self.port_lock.release()
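Taken together, the patch lets the physical replication toolchain run through PgBouncer as well. Below is a sketch of pointing pg_basebackup at a PgBouncer endpoint, mirroring test_physical_rep_pg_basebackup; host, port and the mydb entry are placeholders, and per the comment in test/test.ini a [databases] entry named replication may also be needed for pg_basebackup/pg_receivewal until the referenced PostgreSQL change is available:

```python
import subprocess
import tempfile

# Placeholder conninfo pointing at PgBouncer rather than at Postgres directly.
conninfo = "host=localhost port=6432 dbname=mydb user=postgres"

# pg_basebackup opens physical replication connections, which PgBouncer now
# proxies 1:1 to the backing server.
target = tempfile.mkdtemp(prefix="basebackup-")
subprocess.run(
    [
        "pg_basebackup",
        "--dbname", conninfo,
        "--checkpoint=fast",
        "--pgdata", target,
    ],
    check=True,
)
print("base backup written to", target)
```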