diff --git a/.cirrus.yml b/.cirrus.yml index 438db2c10f36..3d58a995af8d 100644 --- a/.cirrus.yml +++ b/.cirrus.yml @@ -103,7 +103,7 @@ task: # - image: rockylinux:8 # - image: centos:centos7 # setup_script: -# - yum -y install autoconf automake diffutils file libevent-devel libtool make openssl-devel pkg-config postgresql-server systemd-devel wget +# - yum -y install autoconf automake diffutils file libevent-devel libtool make openssl-devel pkg-config postgresql-server postgresql-contrib systemd-devel wget # - if cat /etc/centos-release | grep -q ' 7'; then yum -y install python python-pip; else yum -y install python3 python3-pip sudo iptables; fi # - wget -O /tmp/pandoc.tar.gz https://github.com/jgm/pandoc/releases/download/2.10.1/pandoc-2.10.1-linux-amd64.tar.gz # - tar xvzf /tmp/pandoc.tar.gz --strip-components 1 -C /usr/local/ @@ -133,7 +133,7 @@ task: # - image: alpine:latest # setup_script: # - apk update -# - apk add autoconf automake bash build-base libevent-dev libtool openssl openssl-dev pkgconf postgresql python3 py3-pip wget sudo iptables +# - apk add autoconf automake bash build-base libevent-dev libtool openssl openssl-dev pkgconf postgresql postgresql-contrib python3 py3-pip wget sudo iptables # - wget -O /tmp/pandoc.tar.gz https://github.com/jgm/pandoc/releases/download/2.10.1/pandoc-2.10.1-linux-amd64.tar.gz # - tar xvzf /tmp/pandoc.tar.gz --strip-components 1 -C /usr/local/ # - python3 -m pip install -r requirements.txt @@ -161,7 +161,7 @@ task: HAVE_IPV6_LOCALHOST: yes USE_SUDO: true setup_script: - - pkg install -y autoconf automake bash gmake hs-pandoc libevent libtool pkgconf postgresql12-server python devel/py-pip sudo + - pkg install -y autoconf automake bash gmake hs-pandoc libevent libtool pkgconf postgresql12-server postgresql12-contrib python devel/py-pip sudo - pip install -r requirements.txt - kldload pf - echo 'anchor "pgbouncer_test/*"' >> /etc/pf.conf diff --git a/.editorconfig b/.editorconfig index 4c8a713d613e..678f626abaf2 100644 --- a/.editorconfig +++ b/.editorconfig @@ -12,6 +12,16 @@ trim_trailing_whitespace = true indent_style = tab indent_size = 8 +[hba_test.{eval,rules}] +indent_style = tab +indent_size = 8 + +[hba_test.rules] +# Disable trailing_whitespace check for hba_test.rules, because one of the +# tests in that file is that parsing doesn't break in case of trailing +# whitespace. +trim_trailing_whitespace = false + [*.py] indent_style = space indent_size = 4 diff --git a/Makefile b/Makefile index 5158c2af3f3d..68f26d063c69 100644 --- a/Makefile +++ b/Makefile @@ -25,6 +25,7 @@ pgbouncer_SOURCES = \ src/util.c \ src/varcache.c \ src/common/base64.c \ + src/common/bool.c \ src/common/saslprep.c \ src/common/scram-common.c \ src/common/unicode_norm.c \ @@ -51,6 +52,7 @@ pgbouncer_SOURCES = \ include/util.h \ include/varcache.h \ include/common/base64.h \ + include/common/builtins.h \ include/common/pg_wchar.h \ include/common/postgres_compat.h \ include/common/saslprep.h \ diff --git a/doc/config.md b/doc/config.md index 7339246b3277..8aab1ce9efe8 100644 --- a/doc/config.md +++ b/doc/config.md @@ -1285,7 +1285,7 @@ The file follows the format of the PostgreSQL `pg_hba.conf` file (see ). * Supported record types: `local`, `host`, `hostssl`, `hostnossl`. -* Database field: Supports `all`, `sameuser`, `@file`, multiple names. Not supported: `replication`, `samerole`, `samegroup`. +* Database field: Supports `all`, `replication`, `sameuser`, `@file`, multiple names. Not supported: `samerole`, `samegroup`. 
* User name field: Supports `all`, `@file`, multiple names. Not supported: `+groupname`. * Address field: Supports IPv4, IPv6. Not supported: DNS names, domain prefixes. * Auth-method field: Only methods supported by PgBouncer's `auth_type` diff --git a/include/bouncer.h b/include/bouncer.h index eee300887967..1a88a02565b4 100644 --- a/include/bouncer.h +++ b/include/bouncer.h @@ -104,6 +104,7 @@ typedef union PgAddr PgAddr; typedef enum SocketState SocketState; typedef struct PktHdr PktHdr; typedef struct ScramState ScramState; +typedef enum ReplicationType ReplicationType; extern int cf_sbuf_len; @@ -305,6 +306,11 @@ struct PgPool { * Clients that sent cancel request, to cancel another client its query. * These requests are waiting for a new server connection to be opened, * before the request can be forwarded. + * + * This is a separate list from waiting_client_list, because we want to + * give cancel requests priority over regular clients. The main reason + * for this is that a cancel request might free up a connection, + * which can be used for one of the waiting clients. */ struct StatList waiting_cancel_req_list; @@ -504,6 +510,13 @@ struct PgDatabase { struct AATree user_tree; /* users that have been queried on this database */ }; +enum ReplicationType { + REPLICATION_NONE = 0, + REPLICATION_LOGICAL, + REPLICATION_PHYSICAL, +}; + +extern const char *replication_type_parameters[3]; /* * A client or server connection. */ @@ -543,6 +556,9 @@ struct PgSocket { bool wait_sslchar : 1; /* server: waiting for ssl response: S/N */ + ReplicationType replication; /* If this is a replication connection */ + char *startup_options; /* only tracked for replication connections */ + int expect_rfq_count; /* client: count of ReadyForQuery packets client should see */ usec_t connect_time; /* when connection was made */ diff --git a/include/common/builtins.h b/include/common/builtins.h new file mode 100644 index 000000000000..aede401a94db --- /dev/null +++ b/include/common/builtins.h @@ -0,0 +1,17 @@ +/*------------------------------------------------------------------------- + * + * builtins.h + * Declarations for operations on built-in types.
+ * + * + * Portions Copyright (c) 1996-2023, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + * + * include/common/builtins.h + * + *------------------------------------------------------------------------- + */ + +/* bool.c */ +extern bool parse_bool(const char *value, bool *result); +extern bool parse_bool_with_len(const char *value, size_t len, bool *result); diff --git a/include/common/postgres_compat.h b/include/common/postgres_compat.h index 71afe25192af..98267fba5095 100644 --- a/include/common/postgres_compat.h +++ b/include/common/postgres_compat.h @@ -7,6 +7,7 @@ /* from c.h */ #include +#include #define int8 int8_t #define uint8 uint8_t @@ -15,6 +16,7 @@ #define lengthof(array) (sizeof (array) / sizeof ((array)[0])) #define pg_hton32(x) htobe32(x) +#define pg_strncasecmp strncasecmp #define pg_attribute_noreturn() _NORETURN @@ -30,6 +32,5 @@ #define pg_sha256_update(ctx, data, len) sha256_update(ctx, data, len) #define pg_sha256_final(ctx, dst) sha256_final(ctx, dst) - /* define this to use non-server code paths */ #define FRONTEND diff --git a/include/hba.h b/include/hba.h index 7e8c3bb752f6..fc92b96bf7d6 100644 --- a/include/hba.h +++ b/include/hba.h @@ -20,4 +20,4 @@ struct HBA; struct HBA *hba_load_rules(const char *fn); void hba_free(struct HBA *hba); -int hba_eval(struct HBA *hba, PgAddr *addr, bool is_tls, const char *dbname, const char *username); +int hba_eval(struct HBA *hba, PgAddr *addr, bool is_tls, ReplicationType replication, const char *dbname, const char *username); diff --git a/include/objects.h b/include/objects.h index 1a7f7aa535cf..db1432aa8285 100644 --- a/include/objects.h +++ b/include/objects.h @@ -42,6 +42,7 @@ PgPool *get_pool(PgDatabase *, PgUser *); PgPool *get_peer_pool(PgDatabase *); PgSocket *compare_connections_by_time(PgSocket *lhs, PgSocket *rhs); bool evict_connection(PgDatabase *db) _MUSTCHECK; +bool evict_pool_connection(PgPool *pool) _MUSTCHECK; bool evict_user_connection(PgUser *user) _MUSTCHECK; bool find_server(PgSocket *client) _MUSTCHECK; bool life_over(PgSocket *server); diff --git a/include/pktbuf.h b/include/pktbuf.h index 7c51a0ff3e6c..35c8c26cd35f 100644 --- a/include/pktbuf.h +++ b/include/pktbuf.h @@ -97,9 +97,6 @@ void pktbuf_write_ExtQuery(PktBuf *buf, const char *query, int nargs, ...); #define pktbuf_write_CancelRequest(buf, key) \ pktbuf_write_generic(buf, PKT_CANCEL, "b", key, 8) -#define pktbuf_write_StartupMessage(buf, user, parms, parms_len) \ - pktbuf_write_generic(buf, PKT_STARTUP, "bsss", parms, parms_len, "user", user, "") - #define pktbuf_write_PasswordMessage(buf, psw) \ pktbuf_write_generic(buf, 'p', "s", psw) diff --git a/include/server.h b/include/server.h index bb6195a5fbaf..7ca791a8a44c 100644 --- a/include/server.h +++ b/include/server.h @@ -18,6 +18,7 @@ bool server_proto(SBuf *sbuf, SBufEvent evtype, struct MBuf *pkt) _MUSTCHECK; void kill_pool_logins(PgPool *pool, const char *sqlstate, const char *msg); +int server_pool_mode(PgSocket *server) _MUSTCHECK; int pool_pool_mode(PgPool *pool) _MUSTCHECK; int pool_pool_size(PgPool *pool) _MUSTCHECK; int pool_min_pool_size(PgPool *pool) _MUSTCHECK; diff --git a/include/util.h b/include/util.h index 1a979b010cc2..93d419200fa7 100644 --- a/include/util.h +++ b/include/util.h @@ -70,3 +70,5 @@ bool cf_set_authdb(struct CfValue *cv, const char *value); /* reserved database name checking */ bool check_reserved_database(const char *value); + +bool strings_equal(const char *str_left, const char *str_right) 
_MUSTCHECK; diff --git a/include/varcache.h b/include/varcache.h index 21599666a900..eb4caa2532ff 100644 --- a/include/varcache.h +++ b/include/varcache.h @@ -19,6 +19,7 @@ void init_var_lookup(const char *cf_track_extra_parameters); int get_num_var_cached(void); bool varcache_set(VarCache *cache, const char *key, const char *value) /* _MUSTCHECK */; bool varcache_apply(PgSocket *server, PgSocket *client, bool *changes_p) _MUSTCHECK; +void varcache_apply_startup(PktBuf *pkt, PgSocket *client); void varcache_fill_unset(VarCache *src, PgSocket *dst); void varcache_clean(VarCache *cache); void varcache_add_params(PktBuf *pkt, VarCache *vars); diff --git a/src/client.c b/src/client.c index b37483ca358e..1186d351e6ad 100644 --- a/src/client.c +++ b/src/client.c @@ -23,6 +23,7 @@ #include "bouncer.h" #include "pam.h" #include "scram.h" +#include "common/builtins.h" #include @@ -267,11 +268,15 @@ static bool finish_set_pool(PgSocket *client, bool takeover) if (client->sbuf.tls) { char infobuf[96] = ""; tls_get_connection_info(client->sbuf.tls, infobuf, sizeof infobuf); - slog_info(client, "login attempt: db=%s user=%s tls=%s", - client->db->name, client->login_user->name, infobuf); + slog_info(client, "login attempt: db=%s user=%s tls=%s replication=%s", + client->db->name, + client->login_user->name, + infobuf, + replication_type_parameters[client->replication]); } else { - slog_info(client, "login attempt: db=%s user=%s tls=no", - client->db->name, client->login_user->name); + slog_info(client, "login attempt: db=%s user=%s tls=no replication=%s", + client->db->name, client->login_user->name, + replication_type_parameters[client->replication]); } } @@ -288,8 +293,13 @@ static bool finish_set_pool(PgSocket *client, bool takeover) auth = cf_auth_type; if (auth == AUTH_HBA) { - auth = hba_eval(parsed_hba, &client->remote_addr, !!client->sbuf.tls, - client->db->name, client->login_user->name); + auth = hba_eval( + parsed_hba, + &client->remote_addr, + !!client->sbuf.tls, + client->replication, + client->db->name, + client->login_user->name); } if (auth == AUTH_MD5) { @@ -595,6 +605,25 @@ static bool set_startup_options(PgSocket *client, const char *options) char arg_buf[400]; struct MBuf arg; const char *position = options; + + if (client->replication) { + /* + * Since replication clients will be bound 1-to-1 to a server + * connection, we can support any configuration flags and + * fields in the options startup parameter, because we can + * simply send the exact same value for the options parameter + * when opening the replication connection to the server. This + * allows us to also support GUCs that don't have the + * GUC_REPORT flag, specifically extra_float_digits which is a + * configuration that is set by CREATE SUBSCRIPTION in the + * options parameter. + */ + client->startup_options = strdup(options); + if (!client->startup_options) + disconnect_client(client, true, "out of memory"); + return true; + } + mbuf_init_fixed_writer(&arg, arg_buf, sizeof(arg_buf)); slog_debug(client, "received options: %s", options); @@ -658,12 +687,52 @@ static void set_appname(PgSocket *client, const char *app_name) } } +/* + * set_replication sets the replication field on the client according to the + * given replicationString.
+ */ +static bool set_replication(PgSocket *client, const char *replicationString) +{ + bool replicationBool = false; + if (strcmp(replicationString, "database") == 0) { + client->replication = REPLICATION_LOGICAL; + return true; + } + if (!parse_bool(replicationString, &replicationBool)) { + return false; + } + client->replication = replicationBool ? REPLICATION_PHYSICAL : REPLICATION_NONE; + return true; +} + static bool decide_startup_pool(PgSocket *client, PktHdr *pkt) { const char *username = NULL, *dbname = NULL; const char *key, *val; bool ok; bool appname_found = false; + unsigned original_read_pos = pkt->data.read_pos; + + /* + * First check if we're dealing with a replication connection. Because for + * those we support some additional things when parsing the startup + * parameters, specifically we support any arguments in the options startup + * packet. + */ + while (1) { + ok = mbuf_get_string(&pkt->data, &key); + if (!ok || *key == 0) + break; + ok = mbuf_get_string(&pkt->data, &val); + if (!ok) + break; + if (strcmp(key, "replication") == 0) { + slog_debug(client, "got var: %s=%s", key, val); + set_replication(client, val); + } + } + + pkt->data.read_pos = original_read_pos; while (1) { ok = mbuf_get_string(&pkt->data, &key); @@ -685,6 +754,8 @@ static bool decide_startup_pool(PgSocket *client, PktHdr *pkt) } else if (strcmp(key, "application_name") == 0) { set_appname(client, val); appname_found = true; + } else if (strcmp(key, "replication") == 0) { + /* do nothing, already checked in the previous loop */ } else if (varcache_set(&client->vars, key, val)) { slog_debug(client, "got var: %s=%s", key, val); } else if (strlist_contains(cf_ignore_startup_params, key)) { diff --git a/src/common/bool.c b/src/common/bool.c new file mode 100644 index 000000000000..b0543e1c58d2 --- /dev/null +++ b/src/common/bool.c @@ -0,0 +1,112 @@ +/*------------------------------------------------------------------------- + * + * bool.c + * Functions for the built-in type "bool". + * + * Portions Copyright (c) 1996-2023, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + * + * + * IDENTIFICATION + * src/backend/utils/adt/bool.c + * + *------------------------------------------------------------------------- + */ + +#include "common/postgres_compat.h" + +#include "common/builtins.h" + +/* + * Try to interpret value as boolean value. Valid values are: true, + * false, yes, no, on, off, 1, 0; as well as unique prefixes thereof. + * If the string parses okay, return true, else false. + * If okay and result is not NULL, return the value in *result. + */ +bool +parse_bool(const char *value, bool *result) +{ + return parse_bool_with_len(value, strlen(value), result); +} + +bool +parse_bool_with_len(const char *value, size_t len, bool *result) +{ + switch (*value) + { + case 't': + case 'T': + if (pg_strncasecmp(value, "true", len) == 0) + { + if (result) + *result = true; + return true; + } + break; + case 'f': + case 'F': + if (pg_strncasecmp(value, "false", len) == 0) + { + if (result) + *result = false; + return true; + } + break; + case 'y': + case 'Y': + if (pg_strncasecmp(value, "yes", len) == 0) + { + if (result) + *result = true; + return true; + } + break; + case 'n': + case 'N': + if (pg_strncasecmp(value, "no", len) == 0) + { + if (result) + *result = false; + return true; + } + break; + case 'o': + case 'O': + /* 'o' is not unique enough */ + if (pg_strncasecmp(value, "on", (len > 2 ? 
len : 2)) == 0) + { + if (result) + *result = true; + return true; + } + else if (pg_strncasecmp(value, "off", (len > 2 ? len : 2)) == 0) + { + if (result) + *result = false; + return true; + } + break; + case '1': + if (len == 1) + { + if (result) + *result = true; + return true; + } + break; + case '0': + if (len == 1) + { + if (result) + *result = false; + return true; + } + break; + default: + break; + } + + if (result) + *result = false; /* suppress compiler warning */ + return false; +} diff --git a/src/hba.c b/src/hba.c index 99047b41c979..5af965259483 100644 --- a/src/hba.c +++ b/src/hba.c @@ -31,8 +31,9 @@ enum RuleType { RULE_HOSTNOSSL, }; -#define NAME_ALL 1 -#define NAME_SAMEUSER 2 +#define NAME_ALL 1 +#define NAME_SAMEUSER 2 +#define NAME_REPLICATION 4 struct NameSlot { size_t strlen; @@ -399,8 +400,8 @@ static bool parse_names(struct HBAName *hname, struct TokParser *tp, bool is_db, return false; } if (eat_kw(tp, "replication")) { - log_warning("replication is not supported"); - return false; + hname->flags |= NAME_REPLICATION; + goto eat_comma; } } @@ -699,7 +700,7 @@ static bool match_inet6(const struct HBARule *rule, PgAddr *addr) (src[2] & mask[2]) == base[2] && (src[3] & mask[3]) == base[3]; } -int hba_eval(struct HBA *hba, PgAddr *addr, bool is_tls, const char *dbname, const char *username) +int hba_eval(struct HBA *hba, PgAddr *addr, bool is_tls, ReplicationType replication, const char *dbname, const char *username) { struct List *el; struct HBARule *rule; @@ -733,8 +734,14 @@ int hba_eval(struct HBA *hba, PgAddr *addr, bool is_tls, const char *dbname, con } /* match db & user */ - if (!name_match(&rule->db_name, dbname, dbnamelen, username)) - continue; + if (replication == REPLICATION_PHYSICAL) { + if (!(rule->db_name.flags & NAME_REPLICATION)) { + continue; + } + } else { + if (!name_match(&rule->db_name, dbname, dbnamelen, username)) + continue; + } if (!name_match(&rule->user_name, username, unamelen, dbname)) continue; diff --git a/src/janitor.c b/src/janitor.c index 2616a3c3bb68..db7c40e689c7 100644 --- a/src/janitor.c +++ b/src/janitor.c @@ -187,7 +187,14 @@ static void per_loop_activate(PgPool *pool) sv_used = statlist_count(&pool->used_server_list); statlist_for_each_safe(item, &pool->waiting_client_list, tmp) { client = container_of(item, PgSocket, head); - if (!statlist_empty(&pool->idle_server_list)) { + if (client->replication) { + /* + * For replication connections we always launch + * a new connection, but we continue with the loop, + * because there might be normal clients waiting too. 
+ */ + launch_new_connection(pool, /* evict_if_needed= */ true); + } else if (!statlist_empty(&pool->idle_server_list)) { /* db not fully initialized after reboot */ if (client->wait_for_welcome && !pool->welcome_msg_ready) { launch_new_connection(pool, /* evict_if_needed= */ true); @@ -243,7 +250,7 @@ static int per_loop_suspend(PgPool *pool, bool force_suspend) active += suspend_socket_list(&pool->active_client_list, force_suspend); - /* this list is not suspendable, but still need force_suspend and counting */ + /* these lists are not suspendable, but still need force_suspend and counting */ active += suspend_socket_list(&pool->waiting_client_list, force_suspend); if (active) per_loop_activate(pool); @@ -534,13 +541,13 @@ static void pool_server_maint(PgPool *pool) check_unused_servers(pool, &pool->idle_server_list, 1); /* disconnect close_needed active servers if server_fast_close is set */ - if (cf_server_fast_close) { - statlist_for_each_safe(item, &pool->active_server_list, tmp) { - server = container_of(item, PgSocket, head); - Assert(server->state == SV_ACTIVE); - if (server->ready && server->close_needed) - disconnect_server(server, true, "database configuration changed"); - } + statlist_for_each_safe(item, &pool->active_server_list, tmp) { + server = container_of(item, PgSocket, head); + Assert(server->state == SV_ACTIVE); + if (cf_server_fast_close && server->ready && server->close_needed) + disconnect_server(server, true, "database configuration changed"); + if (server->replication != REPLICATION_NONE && server->close_needed) + disconnect_server(server, true, "database configuration changed"); } /* handle query_timeout and idle_transaction_timeout */ diff --git a/src/loader.c b/src/loader.c index c76484568c64..56a9717b0160 100644 --- a/src/loader.c +++ b/src/loader.c @@ -120,20 +120,6 @@ static char * cstr_get_pair(char *p, return cstr_skip_ws(p); } -/* - * Same as strcmp, but handles NULLs. If both sides are NULL, returns "true". 
- */ -static bool strings_equal(const char *str_left, const char *str_right) -{ - if (str_left == NULL && str_right == NULL) - return true; - - if (str_left == NULL || str_right == NULL) - return false; - - return strcmp(str_left, str_right) == 0; -} - static bool set_auth_dbname(PgDatabase *db, const char *new_auth_dbname) { if (strings_equal(db->auth_dbname, new_auth_dbname)) diff --git a/src/objects.c b/src/objects.c index ce6c5b3cb80b..87b7435ef340 100644 --- a/src/objects.c +++ b/src/objects.c @@ -73,6 +73,12 @@ static STATLIST(justfree_server_list); /* init autodb idle list */ STATLIST(autodatabase_idle_list); +const char *replication_type_parameters[] = { + [REPLICATION_NONE] = "no", + [REPLICATION_LOGICAL] = "database", + [REPLICATION_PHYSICAL] = "yes", +}; + /* fast way to get number of active clients */ int get_active_client_count(void) { @@ -695,10 +701,12 @@ PgPool *get_peer_pool(PgDatabase *db) /* deactivate socket and put into wait queue */ static void pause_client(PgSocket *client) { + SocketState newstate; Assert(client->state == CL_ACTIVE || client->state == CL_LOGIN); - slog_debug(client, "pause_client"); - change_client_state(client, CL_WAITING); + newstate = CL_WAITING; + + change_client_state(client, newstate); if (!sbuf_pause(&client->sbuf)) disconnect_client(client, true, "pause failed"); } @@ -802,9 +810,22 @@ bool find_server(PgSocket *client) if (client->link) return true; + slog_debug(client, "find_server: no linked server yet"); + /* try to get idle server, if allowed */ if (cf_pause_mode == P_PAUSE || pool->db->db_paused) { server = NULL; + } else if (client->replication) { + /* + * For replication clients we open dedicated server connections. These + * connections are linked to a client as soon as the server is ready, + * instead of lazily being assigned to a client only when the client + * sends a query. So if we reach this point we know that that has not + * happened yet, and we need to create a new replication connection for + * this client. 
+ */ + launch_new_connection(pool, /*evict_if_needed= */ true); + server = NULL; } else { while (1) { server = first_socket(&pool->idle_server_list); @@ -859,8 +880,11 @@ static bool reuse_on_release(PgSocket *server) { bool res = true; PgPool *pool = server->pool; - PgSocket *client = first_socket(&pool->waiting_client_list); - if (client) { + PgSocket *client; + Assert(!server->replication); + slog_debug(server, "reuse_on_release: replication %d", server->replication); + client = first_socket(&pool->waiting_client_list); + if (client && !client->replication) { activate_client(client); /* @@ -932,7 +956,7 @@ bool release_server(PgSocket *server) } if (*cf_server_reset_query && (cf_server_reset_query_always || - pool_pool_mode(pool) == POOL_SESSION)) { + server_pool_mode(server) == POOL_SESSION)) { /* notify reset is required */ newstate = SV_TESTED; } else if (cf_server_check_delay == 0 && *cf_server_check_query) { @@ -986,6 +1010,18 @@ bool release_server(PgSocket *server) return true; } + if (server->replication) { + if (server->link) { + slog_debug(server, "release_server: new replication connection ready"); + change_server_state(server, SV_ACTIVE); + activate_client(server->link); + return true; + } else { + disconnect_server(server, true, "replication client was closed"); + return false; + } + } + Assert(server->link == NULL); slog_noise(server, "release_server: new state=%d", newstate); change_server_state(server, newstate); @@ -1000,6 +1036,28 @@ bool release_server(PgSocket *server) return true; } +static void unlink_server(PgSocket *server, const char *reason) +{ + PgSocket *client; + if (!server->link) + return; + + client = server->link; + + client->link = NULL; + server->link = NULL; + /* + * Send reason to client if it is already + * logged in, otherwise send generic message. + */ + if (client->state == CL_ACTIVE || client->state == CL_WAITING) + disconnect_client(client, true, "%s", reason); + else if (client->state == CL_ACTIVE_CANCEL) + disconnect_client(client, false, "successfully sent cancel request"); + else + disconnect_client(client, true, "bouncer config error"); +} + /* * close server connection * @@ -1031,26 +1089,9 @@ void disconnect_server(PgSocket *server, bool send_term, const char *reason, ... switch (server->state) { case SV_ACTIVE_CANCEL: - case SV_ACTIVE: { - PgSocket *client = server->link; - - if (client) { - client->link = NULL; - server->link = NULL; - /* - * Send reason to client if it is already - * logged in, otherwise send generic message. - */ - if (client->state == CL_ACTIVE || client->state == CL_WAITING) - disconnect_client(client, true, "%s", reason); - else if (client->state == CL_ACTIVE_CANCEL) - disconnect_client(client, false, "successfully sent cancel request"); - else - disconnect_client(client, true, "bouncer config error"); - } - + case SV_ACTIVE: + unlink_server(server, reason); break; - } case SV_TESTED: case SV_USED: case SV_IDLE: @@ -1074,6 +1115,8 @@ void disconnect_server(PgSocket *server, bool send_term, const char *reason, ... server->pool->last_connect_failed = false; send_term = false; } + if (server->replication) + unlink_server(server, reason); break; default: fatal("bad server state: %d", server->state); @@ -1225,6 +1268,16 @@ void disconnect_client_sqlstate(PgSocket *client, bool notify, const char *sqlst break; case CL_WAITING: case CL_WAITING_LOGIN: + /* + * replication connections might already be linked to a server + * while they are still in a waiting state. 
+ */ + if (client->replication && client->link) { + PgSocket *server = client->link; + server->link = NULL; + client->link = NULL; + disconnect_server(server, false, "replication client disconnected"); + } break; default: fatal("bad client state: %d", client->state); @@ -1249,6 +1302,9 @@ void disconnect_client_sqlstate(PgSocket *client, bool notify, const char *sqlst client->db = NULL; } + free(client->startup_options); + client->startup_options = NULL; + change_client_state(client, CL_JUSTFREE); if (!sbuf_close(&client->sbuf)) log_noise("sbuf_close failed, retry later"); @@ -1443,6 +1499,21 @@ bool evict_connection(PgDatabase *db) return false; } +/* evict the oldest idle connection from the pool */ +bool evict_pool_connection(PgPool *pool) +{ + PgSocket *oldest_connection = NULL; + + oldest_connection = compare_connections_by_time(oldest_connection, last_socket(&pool->idle_server_list)); + + if (oldest_connection) { + disconnect_server(oldest_connection, true, "evicted"); + return true; + } + return false; +} + + /* evict the single most idle connection from among all pools to make room in the user */ bool evict_user_connection(PgUser *user) { @@ -1483,6 +1554,7 @@ void launch_new_connection(PgPool *pool, bool evict_if_needed) PgSocket *server; int max; + log_debug("launch_new_connection: start"); /* * Allow only a single connection attempt at a time. * @@ -1538,9 +1610,9 @@ void launch_new_connection(PgPool *pool, bool evict_if_needed) /* is it allowed to add servers? */ if (max >= pool_pool_size(pool) && pool->welcome_msg_ready) { /* should we use reserve pool? */ + PgSocket *c = first_socket(&pool->waiting_client_list); if (cf_res_pool_timeout && pool_res_pool_size(pool)) { usec_t now = get_cached_time(); - PgSocket *c = first_socket(&pool->waiting_client_list); if (c && (now - c->request_time) >= cf_res_pool_timeout) { if (max < pool_pool_size(pool) + pool_res_pool_size(pool)) { slog_warning(c, "taking connection from reserve_pool"); @@ -1548,6 +1620,15 @@ void launch_new_connection(PgPool *pool, bool evict_if_needed) } } } + + if (c && c->replication) { + while (evict_if_needed && pool_pool_size(pool) >= max) { + if (!evict_pool_connection(pool)) + break; + } + if (pool_pool_size(pool) < max) + goto allow_new; + } log_debug("launch_new_connection: pool full (%d >= %d)", max, pool_pool_size(pool)); return; diff --git a/src/proto.c b/src/proto.c index fccde842c332..2d5a182d3984 100644 --- a/src/proto.c +++ b/src/proto.c @@ -593,15 +593,69 @@ bool answer_authreq(PgSocket *server, PktHdr *pkt) bool send_startup_packet(PgSocket *server) { - PgDatabase *db = server->pool->db; - const char *username = server->pool->user->name; - PktBuf *pkt; - - pkt = pktbuf_temp(); - pktbuf_write_StartupMessage(pkt, username, - db->startup_params->buf, - db->startup_params->write_pos); - return pktbuf_send_immediate(pkt, server); + PgPool *pool = server->pool; + PgDatabase *db = pool->db; + const char *username = pool->user->name; + PktBuf *pkt = pktbuf_temp(); + PgSocket *client = NULL; + + pktbuf_start_packet(pkt, PKT_STARTUP); + pktbuf_put_bytes(pkt, db->startup_params->buf, db->startup_params->write_pos); + + /* + * If the next client in the list is a replication connection, we need + * to do some special stuff for it. 
+ */ + client = first_socket(&pool->waiting_client_list); + if (client && client->replication) { + server->replication = client->replication; + pktbuf_put_string(pkt, "replication"); + slog_debug(server, "send_startup_packet: creating replication connection"); + pktbuf_put_string(pkt, replication_type_parameters[server->replication]); + + /* + * For a replication connection we apply the varcache in the + * startup instead of through SET commands after connecting. + * The main reason to do so is because physical replication + * connections don't allow SET commands. A second reason is + * because it allows us to skip running the SET logic + * completely, which normally requires waiting on multiple + * server responses. This SET logic is normally executed in the + * codepath where we link the client to the server + * (find_server), but because we link the client here already + * we don't run that code for replication connections. Adding + * the varcache parameters to the startup message allows us to + * skip the dance that involves sending Query packets and + * waiting for responses. + */ + varcache_apply_startup(pkt, client); + if (client->startup_options) { + pktbuf_put_string(pkt, "options"); + pktbuf_put_string(pkt, client->startup_options); + } + } + + pktbuf_put_string(pkt, "user"); + pktbuf_put_string(pkt, username); + pktbuf_put_string(pkt, ""); /* terminator required in StartupMessage */ + pktbuf_finish_packet(pkt); + + if (!pktbuf_send_immediate(pkt, server)) { + return false; + } + + if (server->replication) { + /* + * We link replication connections to a client directly when they are + * created. One reason for this is that the startup parameters need to be + * forwarded, because physical replication connections don't allow SET + * commands. Another reason is so that we don't need a separate state. + */ + client->link = server; + server->link = client; + } + + return true; } bool send_sslreq_packet(PgSocket *server) diff --git a/src/server.c b/src/server.c index 398b4fe11589..cd5fab5a1017 100644 --- a/src/server.c +++ b/src/server.c @@ -122,7 +122,18 @@ static bool handle_server_startup(PgSocket *server, PktHdr *pkt) break; case 'E': /* ErrorResponse */ - if (!server->pool->welcome_msg_ready) + /* + * If we haven't been able to connect to the server since the + * startup (or call to tag_pool_dirty), then we drop all + * clients that are currently trying to log in because they + * will almost certainly hit the same error. + * + * However, we don't do this if it's a replication connection, + * because those can fail due to a variety of missing + * permissions, while normal connections would still be able + * to connect and query the database just fine. + */ + if (!server->pool->welcome_msg_ready && !server->replication) kill_pool_logins_server_error(server->pool, pkt); else log_server_error("S: login failed", pkt); @@ -192,6 +203,18 @@ static bool handle_server_startup(PgSocket *server, PktHdr *pkt) return res; } +/* + * server_pool_mode returns the pool_mode for the server. It specifically + * forces session pooling if the server is a replication connection, because + * replication connections require session pooling to work correctly.
+ */ +int server_pool_mode(PgSocket *server) +{ + if (server->replication) + return POOL_SESSION; + return pool_pool_mode(server->pool); +} + int pool_pool_mode(PgPool *pool) { int pool_mode = pool->user->pool_mode; @@ -270,7 +293,7 @@ static bool handle_server_work(PgSocket *server, PktHdr *pkt) /* set ready only if no tx */ if (state == 'I') { ready = true; - } else if (pool_pool_mode(server->pool) == POOL_STMT) { + } else if (server_pool_mode(server) == POOL_STMT) { disconnect_server(server, true, "transaction blocks not allowed in statement pooling mode"); return false; } else if (state == 'T' || state == 'E') { @@ -343,6 +366,7 @@ static bool handle_server_work(PgSocket *server, PktHdr *pkt) /* copy mode */ case 'G': /* CopyInResponse */ case 'H': /* CopyOutResponse */ + case 'W': /* CopyBothResponse */ server->copy_mode = true; break; /* chat packets */ @@ -619,7 +643,7 @@ bool server_proto(SBuf *sbuf, SBufEvent evtype, struct MBuf *data) break; } - if (pool_pool_mode(pool) != POOL_SESSION || server->state == SV_TESTED || server->resetting) { + if (server_pool_mode(server) != POOL_SESSION || server->state == SV_TESTED || server->resetting) { server->resetting = false; switch (server->state) { case SV_ACTIVE: diff --git a/src/util.c b/src/util.c index b87b226fe93a..1615b323ac6b 100644 --- a/src/util.c +++ b/src/util.c @@ -496,3 +496,18 @@ bool check_reserved_database(const char *value) } return true; } + +/* + * Same as strcmp, but handles NULLs. If both sides are NULL, returns "true". + */ +bool strings_equal(const char *str_left, const char *str_right) +{ + if (str_left == NULL && str_right == NULL) + return true; + + if (str_left == NULL || str_right == NULL) + return false; + + return strcmp(str_left, str_right) == 0; +} + diff --git a/src/varcache.c b/src/varcache.c index 37de1709aada..1f671ec8ca4c 100644 --- a/src/varcache.c +++ b/src/varcache.c @@ -233,6 +233,21 @@ bool varcache_apply(PgSocket *server, PgSocket *client, bool *changes_p) return pktbuf_send_immediate(pkt, server); } +void varcache_apply_startup(PktBuf *pkt, PgSocket *client) +{ + const struct var_lookup *lk, *tmp; + + HASH_ITER(hh, lookup_map, lk, tmp) { + struct PStr *val = get_value(&client->vars, lk); + if (!val) + continue; + + slog_debug(client, "varcache_apply_startup: %s=%s", lk->name, val->str); + pktbuf_put_string(pkt, lk->name); + pktbuf_put_string(pkt, val->str); + } +} + void varcache_fill_unset(VarCache *src, PgSocket *dst) { struct PStr *srcval, *dstval; diff --git a/test/conftest.py b/test/conftest.py index cf55ef91e8b1..bf050b5595e5 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -63,6 +63,7 @@ def pg(tmp_path_factory, cert_dir): f.write(f"ssl_cert_file='{cert}'\n") f.write(f"ssl_key_file='{key}'\n") + pg.nossl_access("replication", "trust", user="postgres") pg.nossl_access("all", "trust") pg.nossl_access("p4", "password") pg.nossl_access("p5", "md5") diff --git a/test/hba_test.c b/test/hba_test.c index 839b3e48c0db..90b0dcc436e7 100644 --- a/test/hba_test.c +++ b/test/hba_test.c @@ -65,9 +65,11 @@ static char *get_token(char **ln_p) static int hba_test_eval(struct HBA *hba, char *ln, int linenr) { - const char *addr=NULL, *user=NULL, *db=NULL, *tls=NULL, *exp=NULL; + const char *addr=NULL, *user=NULL, *db=NULL, *modifier=NULL, *exp=NULL; PgAddr pgaddr; int res; + bool tls; + ReplicationType replication; if (ln[0] == '#') return 0; @@ -75,7 +77,9 @@ static int hba_test_eval(struct HBA *hba, char *ln, int linenr) db = get_token(&ln); user = get_token(&ln); addr = get_token(&ln); - tls = 
get_token(&ln); + modifier = get_token(&ln); + tls = strings_equal(modifier, "tls"); + replication = strings_equal(modifier, "replication") ? REPLICATION_PHYSICAL : REPLICATION_NONE; if (!exp) return 0; if (!db || !user) @@ -84,7 +88,7 @@ static int hba_test_eval(struct HBA *hba, char *ln, int linenr) if (!pga_pton(&pgaddr, addr, 9999)) die("hbatest: invalid addr on line #%d", linenr); - res = hba_eval(hba, &pgaddr, !!tls, db, user); + res = hba_eval(hba, &pgaddr, tls, replication, db, user); if (strcmp(method2string(res), exp) == 0) { res = 0; } else { diff --git a/test/hba_test.eval b/test/hba_test.eval index 3cd28b98f234..f25f7b6cce2c 100644 --- a/test/hba_test.eval +++ b/test/hba_test.eval @@ -79,3 +79,16 @@ md5 mdb muser ff11:2::1 md5 mdb muser ff22:3::1 trust mdb muser ::1 reject mdb muser ::2 + +# replication +reject mdb muser ::1 replication +reject db userp unix replication +reject replication userp unix replication +trust db admin ::1 replication +trust replication admin ::1 replication +reject db admin ::1 +reject replication admin ::1 +trust db admin2 ::1 replication +trust replication admin2 ::1 replication +trust db2 admin2 ::1 +reject replication admin2 ::1 diff --git a/test/hba_test.rules b/test/hba_test.rules index 1efe467e8b75..03d6cfd45a5a 100644 --- a/test/hba_test.rules +++ b/test/hba_test.rules @@ -5,6 +5,7 @@ # hostssl DATABASE USER ADDRESS METHOD [OPTIONS] # hostnossl DATABASE USER ADDRESS METHOD [OPTIONS] +# The following lines test that weird whitespace does not break parsing # ws # z @@ -48,3 +49,7 @@ host mdb2 muser 128.0.0.0/1 cert host mdb muser ff11::0/16 md5 host mdb muser ff20::/12 md5 host mdb muser ::1/128 trust + +# replication +host replication admin ::1/128 trust +host db2,replication admin2 ::1/128 trust diff --git a/test/ssl/test.ini b/test/ssl/test.ini index 7d76e30c94b5..1e7158619521 100644 --- a/test/ssl/test.ini +++ b/test/ssl/test.ini @@ -1,6 +1,7 @@ [databases] p0 = port=6666 host=localhost dbname=p0 user=bouncer pool_size=2 p1 = port=6666 host=localhost dbname=p1 user=bouncer +p7a= port=6666 host=localhost dbname=p7 [pgbouncer] logfile = test.log diff --git a/test/test.ini b/test/test.ini index 6685f8e63791..31090ac8c459 100644 --- a/test/test.ini +++ b/test/test.ini @@ -25,12 +25,22 @@ p7b= port=6666 host=127.0.0.1 dbname=p7 p7c= port=6666 host=127.0.0.1 dbname=p7 p8 = port=6666 host=127.0.0.1 dbname=p0 connect_query='set enable_seqscan=off; set enable_nestloop=off' +user_passthrough = port=6666 host=127.0.0.1 dbname=p0 +user_passthrough2 = port=6666 host=127.0.0.1 dbname=p2 +user_passthrough_pool_size2 = port=6666 host=127.0.0.1 dbname=p0 pool_size=2 +user_passthrough_pool_size5 = port=6666 host=127.0.0.1 dbname=p0 pool_size=2 + pauthz = port=6666 host=127.0.0.1 dbname=p7 auth_user=pswcheck auth_dbname=authdb authdb = port=6666 host=127.0.0.1 dbname=p1 auth_user=pswcheck hostlist1 = port=6666 host=127.0.0.1,::1 dbname=p0 user=bouncer hostlist2 = port=6666 host=127.0.0.1,127.0.0.1 dbname=p0 user=bouncer +; Needed for pg_receivewal and pg_basebackup tests until this patch is merged +; in PostgreSQL (and we don't support PG16 anymore): +; https://www.postgresql.org/message-id/flat/CAGECzQTw-dZkVT_RELRzfWRzY714-VaTjoBATYfZq93R8C-auA@mail.gmail.com +replication = port=6666 host=127.0.0.1 dbname=replication + ; commented out except for auto-database tests ;* = port=6666 host=127.0.0.1 diff --git a/test/test_limits.py b/test/test_limits.py index 1daf9afdfac3..401fa2261826 100644 --- a/test/test_limits.py +++ b/test/test_limits.py @@ -1,4 +1,5 @@ 
import asyncio +import time import psycopg import pytest @@ -63,7 +64,7 @@ def test_min_pool_size_with_lower_max_user_connections(bouncer): # Running a query for sufficient time for us to reach the final # connection count in the pool and detect any evictions. - with bouncer.log_contains("new connection to server", times=2): + with bouncer.log_contains(r"new connection to server \(from", times=2): with bouncer.log_contains("closing because: evicted", times=0): bouncer.sleep(2, dbname="p0x", user="maxedout2") @@ -76,7 +77,7 @@ def test_min_pool_size_with_lower_max_db_connections(bouncer): # Running a query for sufficient time for us to reach the final # connection count in the pool and detect any evictions. - with bouncer.log_contains("new connection to server", times=2): + with bouncer.log_contains(r"new connection to server \(from", times=2): with bouncer.log_contains("closing because: evicted", times=0): bouncer.sleep(2, dbname="p0y", user="puser1") diff --git a/test/test_replication.py b/test/test_replication.py new file mode 100644 index 000000000000..a8bbda8e8c6d --- /dev/null +++ b/test/test_replication.py @@ -0,0 +1,280 @@ +import asyncio +import signal +import subprocess +import time + +import psycopg +import psycopg.errors +import pytest +from psycopg import sql + +from .utils import PG_MAJOR_VERSION, WINDOWS, run + + +def test_logical_rep(bouncer): + connect_args = { + "dbname": "user_passthrough", + "replication": "database", + "user": "postgres", + "application_name": "abc", + "options": "-c enable_seqscan=off", + } + # Starting in PG10 you can do other commands over logical rep connections + if PG_MAJOR_VERSION >= 10: + bouncer.test(**connect_args) + assert bouncer.sql_value("SHOW application_name", **connect_args) == "abc" + assert bouncer.sql_value("SHOW enable_seqscan", **connect_args) == "off" + bouncer.sql("IDENTIFY_SYSTEM", **connect_args) + # Do a normal connection to the same pool, to ensure that that doesn't + # break anything + bouncer.test(dbname="user_passthrough", user="postgres") + bouncer.sql("IDENTIFY_SYSTEM", **connect_args) + + +def test_logical_rep_unprivileged(bouncer): + if PG_MAJOR_VERSION < 10: + expected_log = "no pg_hba.conf entry for replication connection" + elif PG_MAJOR_VERSION < 16: + expected_log = "must be superuser or replication role to start walsender" + else: + expected_log = "permission denied to start WAL sender" + + with bouncer.log_contains( + expected_log, + ), bouncer.log_contains( + r"closing because: login failed \(age", times=2 + ), pytest.raises(psycopg.OperationalError, match=r"login failed"): + bouncer.sql("IDENTIFY_SYSTEM", replication="database") + + +@pytest.mark.skipif( + "PG_MAJOR_VERSION < 10", reason="logical replication was introduced in PG10" +) +def test_logical_rep_subscriber(bouncer): + bouncer.admin("set pool_mode=transaction") + + # First create a table and insert a row in the source database.
+ # Also create the replication slot and publication + bouncer.default_db = "user_passthrough" + bouncer.create_schema("test_logical_rep_subscriber") + bouncer.sql("CREATE TABLE test_logical_rep_subscriber.table(a int)") + bouncer.sql("INSERT INTO test_logical_rep_subscriber.table values (1)") + assert ( + bouncer.sql_value("SELECT count(*) FROM test_logical_rep_subscriber.table") == 1 + ) + + bouncer.create_publication( + "mypub", sql.SQL("FOR TABLE test_logical_rep_subscriber.table") + ) + + bouncer.create_logical_replication_slot("test_logical_rep_subscriber", "pgoutput") + + # Create an equivalent, but empty schema in the target database. + # And setup the subscription + bouncer.default_db = "user_passthrough2" + bouncer.create_schema("test_logical_rep_subscriber") + bouncer.sql("CREATE TABLE test_logical_rep_subscriber.table(a int)") + conninfo = bouncer.make_conninfo(dbname="user_passthrough") + bouncer.create_subscription( + "mysub", + sql.SQL( + """ + CONNECTION {} + PUBLICATION mypub + WITH (slot_name=test_logical_rep_subscriber, create_slot=false) + """ + ).format(sql.Literal(conninfo)), + ) + + # The initial copy should now copy over the row + time.sleep(2) + assert ( + bouncer.sql_value("SELECT count(*) FROM test_logical_rep_subscriber.table") >= 1 + ) + + # Insert another row and logical replication should replicate it correctly + bouncer.sql( + "INSERT INTO test_logical_rep_subscriber.table values (2)", + dbname="user_passthrough", + ) + time.sleep(2) + assert ( + bouncer.sql_value("SELECT count(*) FROM test_logical_rep_subscriber.table") >= 2 + ) + + +@pytest.mark.skipif( + "WINDOWS", reason="MINGW does not have contrib package containing test_decoding" +) +def test_logical_rep_pg_recvlogical(bouncer): + bouncer.default_db = "user_passthrough" + bouncer.create_schema("test_logical_rep_pg_recvlogical") + bouncer.sql("CREATE TABLE test_logical_rep_pg_recvlogical.table(a int)") + bouncer.create_logical_replication_slot( + "test_logical_rep_pg_recvlogical", "test_decoding" + ) + process = subprocess.Popen( + [ + "pg_recvlogical", + "--dbname", + bouncer.default_db, + "--host", + bouncer.host, + "--port", + str(bouncer.port), + "--user", + bouncer.default_user, + "--slot=test_logical_rep_pg_recvlogical", + "--file=-", + "--no-loop", + "--start", + ], + stdout=subprocess.PIPE, + ) + assert process.stdout is not None + bouncer.sql("INSERT INTO test_logical_rep_pg_recvlogical.table values (1)") + try: + assert process.stdout.readline().startswith(b"BEGIN ") + assert ( + process.stdout.readline() + == b'table test_logical_rep_pg_recvlogical."table": INSERT: a[integer]:1\n' + ) + assert process.stdout.readline().startswith(b"COMMIT ") + finally: + process.kill() + process.communicate(timeout=5) + + +def test_physical_rep(bouncer): + connect_args = { + "dbname": "user_passthrough", + "replication": "yes", + "user": "postgres", + "application_name": "abc", + "options": "-c enable_seqscan=off", + } + # Starting in PG10 you can do SHOW commands + if PG_MAJOR_VERSION >= 10: + with pytest.raises( + psycopg.errors.FeatureNotSupported, + match="cannot execute SQL commands in WAL sender for physical replication", + ): + bouncer.test(**connect_args) + assert bouncer.sql_value("SHOW application_name", **connect_args) == "abc" + assert bouncer.sql_value("SHOW enable_seqscan", **connect_args) == "off" + bouncer.sql("IDENTIFY_SYSTEM", **connect_args) + # Do a normal connection to the same pool, to ensure that that doesn't + # break anything + bouncer.test(dbname="user_passthrough", user="postgres") 
+ bouncer.sql("IDENTIFY_SYSTEM", **connect_args) + + +def test_physical_rep_unprivileged(bouncer): + with bouncer.log_contains( + r"no pg_hba.conf entry for replication connection from host" + ), bouncer.log_contains( + r"closing because: login failed \(age", times=2 + ), pytest.raises( + psycopg.OperationalError, match=r"login failed" + ): + bouncer.test(replication="yes") + + +@pytest.mark.skipif("PG_MAJOR_VERSION < 10", reason="pg_receivewal was added in PG10") +def test_physical_rep_pg_receivewal(bouncer, tmp_path): + bouncer.default_db = "user_passthrough" + bouncer.create_physical_replication_slot("test_physical_rep_pg_receivewal") + wal_dump_dir = tmp_path / "wal-dump" + wal_dump_dir.mkdir() + + process = subprocess.Popen( + [ + "pg_receivewal", + "--dbname", + bouncer.make_conninfo(), + "--slot=test_physical_rep_pg_receivewal", + "--directory", + str(wal_dump_dir), + ], + ) + time.sleep(3) + + if WINDOWS: + process.terminate() + else: + process.send_signal(signal.SIGINT) + process.communicate(timeout=5) + + if WINDOWS: + assert process.returncode == 1 + else: + assert process.returncode == 0 + + children = list(wal_dump_dir.iterdir()) + assert len(children) > 0 + + +def test_physical_rep_pg_basebackup(bouncer, tmp_path): + bouncer.default_db = "user_passthrough" + dump_dir = tmp_path / "db-dump" + dump_dir.mkdir() + + run( + [ + "pg_basebackup", + "--dbname", + bouncer.make_conninfo(), + "--checkpoint=fast", + "--pgdata", + str(dump_dir), + ], + shell=False, + ) + children = list(dump_dir.iterdir()) + assert len(children) > 0 + print(children) + + +@pytest.mark.asyncio +@pytest.mark.skipif( + "PG_MAJOR_VERSION < 10", + reason="normal SQL commands are only supported in PG10+ on logical replication connections", +) +async def test_replication_pool_size(pg, bouncer): + connect_args = { + "dbname": "user_passthrough_pool_size2", + "replication": "database", + "user": "postgres", + } + start = time.time() + await bouncer.asleep(0.5, times=10, **connect_args) + assert time.time() - start > 2.5 + # Replication connections always get closed right away + assert pg.connection_count("p0") == 0 + + connect_args["dbname"] = "user_passthrough_pool_size5" + start = time.time() + await bouncer.asleep(0.5, times=10, **connect_args) + assert time.time() - start > 1 + # Replication connections always get closed right away + assert pg.connection_count("p0") == 0 + + +@pytest.mark.asyncio +@pytest.mark.skipif( + "PG_MAJOR_VERSION < 10", + reason="normal SQL commands are only supported in PG10+ on logical replication connections", +) +async def test_replication_pool_size_mixed_clients(bouncer): + connect_args = { + "dbname": "user_passthrough_pool_size2", + "user": "postgres", + } + + # Fill the pool with normal connections + await bouncer.asleep(0.5, times=2, **connect_args) + + # Then try to open a replication connection and ensure that it causes + # eviction of one of the normal connections + with bouncer.log_contains("closing because: evicted"): + bouncer.test(**connect_args, replication="database") diff --git a/test/test_ssl.py b/test/test_ssl.py index 3a82c91e1a53..b7a3976ddb07 100644 --- a/test/test_ssl.py +++ b/test/test_ssl.py @@ -329,3 +329,41 @@ def test_client_ssl_scram(bouncer, cert_dir): sslmode="verify-full", sslrootcert=root, ) + + +def test_ssl_replication(pg, bouncer, cert_dir): + root = cert_dir / "TestCA1" / "ca.crt" + key = cert_dir / "TestCA1" / "sites" / "01-localhost.key" + cert = cert_dir / "TestCA1" / "sites" / "01-localhost.crt" + + bouncer.write_ini(f"server_tls_sslmode = 
verify-full") + bouncer.write_ini(f"server_tls_ca_file = {root}") + bouncer.write_ini(f"client_tls_sslmode = require") + bouncer.write_ini(f"client_tls_key_file = {key}") + bouncer.write_ini(f"client_tls_cert_file = {cert}") + bouncer.write_ini(f"client_tls_ca_file = {root}") + bouncer.admin("reload") + pg.ssl_access("all", "trust") + pg.ssl_access("replication", "trust", user="postgres") + pg.configure("ssl=on") + pg.configure(f"ssl_ca_file='{root}'") + + if PG_MAJOR_VERSION < 10 or WINDOWS: + pg.restart() + else: + pg.reload() + + # Logical rep + connect_args = { + "host": "localhost", + "dbname": "p7a", + "replication": "database", + "user": "postgres", + "application_name": "abc", + "sslmode": "verify-full", + "sslrootcert": root, + } + bouncer.psql("IDENTIFY_SYSTEM", **connect_args) + # physical rep + connect_args["replication"] = "true" + bouncer.psql("IDENTIFY_SYSTEM", **connect_args) diff --git a/test/utils.py b/test/utils.py index e02a131ee46e..709c50a0b69f 100644 --- a/test/utils.py +++ b/test/utils.py @@ -21,6 +21,8 @@ import filelock import psycopg +import psycopg.sql +from psycopg import sql TEST_DIR = Path(os.path.dirname(os.path.realpath(__file__))) os.chdir(TEST_DIR) @@ -158,6 +160,38 @@ def get_tls_support(): next_port = PORT_LOWER_BOUND +def notice_handler(diag: psycopg.errors.Diagnostic): + print(f"{diag.severity} ({diag.sqlstate}): {diag.message_primary}") + if diag.message_detail: + print(f"DETAIL: {diag.message_detail}") + if diag.message_hint: + print(f"HINT: {diag.message_hint}") + + +def cleanup_test_leftovers(*nodes): + """ + Cleaning up test leftovers needs to be done in a specific order, because + some of these leftovers depend on others having been removed. They might + even depend on leftovers on other nodes being removed. So this takes a list + of nodes, so that we can clean up all test leftovers globally in the + correct order. + """ + for node in nodes: + node.cleanup_subscriptions() + + for node in nodes: + node.cleanup_publications() + + for node in nodes: + node.cleanup_replication_slots() + + for node in nodes: + node.cleanup_schemas() + + for node in nodes: + node.cleanup_users() + + class PortLock: def __init__(self): global next_port @@ -191,9 +225,19 @@ def __init__(self, host, port): self.default_db = "postgres" self.default_user = "postgres" + # Used to track objects that we want to clean up at the end of a test + self.subscriptions = set() + self.publications = set() + self.replication_slots = set() + self.schemas = set() + self.users = set() + def set_default_connection_options(self, options): + """Sets the default connection options on the given options dictionary""" options.setdefault("dbname", self.default_db) options.setdefault("user", self.default_user) + options.setdefault("host", self.host) + options.setdefault("port", self.port) if ENABLE_VALGRIND: # If valgrind is enabled PgBouncer is a significantly slower to # respond to connection requests, so we wait a little longer. 
@@ -202,24 +246,27 @@ def set_default_connection_options(self, options): options.setdefault("connect_timeout", 3) # needed for Ubuntu 18.04 options.setdefault("client_encoding", "UTF8") + return options + + def make_conninfo(self, **kwargs) -> str: + self.set_default_connection_options(kwargs) + return psycopg.conninfo.make_conninfo(**kwargs) def conn(self, *, autocommit=True, **kwargs): """Open a psycopg connection to this server""" self.set_default_connection_options(kwargs) - return psycopg.connect( + conn = psycopg.connect( autocommit=autocommit, - host=self.host, - port=self.port, **kwargs, ) + conn.add_notice_handler(notice_handler) + return conn def aconn(self, *, autocommit=True, **kwargs): """Open an asynchronous psycopg connection to this server""" self.set_default_connection_options(kwargs) return psycopg.AsyncConnection.connect( autocommit=autocommit, - host=self.host, - port=self.port, **kwargs, ) @@ -448,6 +495,117 @@ def reject_traffic(self): elif BSD: sudo(f"pfctl -a pgbouncer_test/port_{self.port} -F all") + def create_user(self, name, args: typing.Optional[psycopg.sql.Composable] = None): + self.users.add(name) + if args is None: + args = sql.SQL("") + self.sql(sql.SQL("CREATE USER {} {}").format(sql.Identifier(name), args)) + + def create_schema(self, name, dbname=None): + dbname = dbname or self.default_db + self.schemas.add((dbname, name)) + self.sql(sql.SQL("CREATE SCHEMA {}").format(sql.Identifier(name))) + + def create_publication(self, name: str, args: psycopg.sql.Composable, dbname=None): + dbname = dbname or self.default_db + self.publications.add((dbname, name)) + self.sql(sql.SQL("CREATE PUBLICATION {} {}").format(sql.Identifier(name), args)) + + def create_logical_replication_slot(self, name, plugin): + self.replication_slots.add(name) + self.sql( + "SELECT pg_catalog.pg_create_logical_replication_slot(%s,%s)", + (name, plugin), + ) + + def create_physical_replication_slot(self, name): + self.replication_slots.add(name) + self.sql( + "SELECT pg_catalog.pg_create_physical_replication_slot(%s)", + (name,), + ) + + def create_subscription(self, name: str, args: psycopg.sql.Composable, dbname=None): + dbname = dbname or self.default_db + self.subscriptions.add((dbname, name)) + self.sql( + sql.SQL("CREATE SUBSCRIPTION {} {}").format(sql.Identifier(name), args) + ) + + def cleanup_users(self): + for user in self.users: + self.sql(sql.SQL("DROP USER IF EXISTS {}").format(sql.Identifier(user))) + + def cleanup_schemas(self): + for dbname, schema in self.schemas: + self.sql( + sql.SQL("DROP SCHEMA IF EXISTS {} CASCADE").format( + sql.Identifier(schema) + ), + dbname=dbname, + ) + + def cleanup_publications(self): + for dbname, publication in self.publications: + self.sql( + sql.SQL("DROP PUBLICATION IF EXISTS {}").format( + sql.Identifier(publication) + ), + dbname=dbname, + ) + + def cleanup_replication_slots(self): + for slot in self.replication_slots: + start = time.time() + while True: + try: + self.sql( + "SELECT pg_drop_replication_slot(slot_name) FROM pg_replication_slots WHERE slot_name = %s", + (slot,), + ) + except psycopg.errors.ObjectInUse: + if time.time() < start + 10: + time.sleep(0.5) + continue + raise + break + + def cleanup_subscriptions(self): + for dbname, subscription in self.subscriptions: + try: + self.sql( + sql.SQL("ALTER SUBSCRIPTION {} DISABLE").format( + sql.Identifier(subscription) + ), + dbname=dbname, + ) + except psycopg.errors.UndefinedObject: + # Subscription didn't exist already + continue + self.sql( + sql.SQL("ALTER SUBSCRIPTION 
{} SET (slot_name = NONE)").format( + sql.Identifier(subscription) + ), + dbname=dbname, + ) + self.sql( + sql.SQL("DROP SUBSCRIPTION {}").format(sql.Identifier(subscription)), + dbname=dbname, + ) + + def debug(self): + print("Connect manually to:\n ", repr(self.make_conninfo())) + print("Press Enter to continue running the test...") + input() + + def psql_debug(self, **kwargs): + conninfo = self.make_conninfo(**kwargs) + run( + ["psql", f"{conninfo}"], + shell=False, + silent=True, + ) + class Postgres(QueryRunner): def __init__(self, pgdata): @@ -471,6 +629,26 @@ def initdb(self): pgconf.write("log_connections = on\n") pgconf.write("log_disconnections = on\n") pgconf.write("logging_collector = off\n") + + # Allow CREATE SUBSCRIPTION to work + pgconf.write("wal_level = 'logical'\n") + # Faster logical replication status update so tests with logical replication + # run faster + pgconf.write("wal_receiver_status_interval = 1\n") + + # Faster logical replication apply worker launch so tests with logical + # replication run faster. This is used in ApplyLauncherMain in + # src/backend/replication/logical/launcher.c. + pgconf.write("wal_retrieve_retry_interval = '250ms'\n") + + # Make sure there's enough logical replication resources for our + # tests + if PG_MAJOR_VERSION >= 10: + pgconf.write("max_logical_replication_workers = 5\n") + pgconf.write("max_wal_senders = 5\n") + pgconf.write("max_replication_slots = 10\n") + pgconf.write("max_worker_processes = 20\n") + # We need to make the log go to stderr so that the tests can # check what is being logged. This should be the default, but # some packagings change the default configuration. @@ -524,24 +702,24 @@ async def arestart(self): process = await self.apgctl("-m fast restart") await process.communicate() - def nossl_access(self, dbname, auth_type): + def nossl_access(self, dbname, auth_type, user="all"): """Prepends a local non-SSL access to the HBA file""" with self.hba_path.open() as pghba: old_contents = pghba.read() with self.hba_path.open(mode="w") as pghba: if USE_UNIX_SOCKETS: - pghba.write(f"local {dbname} all {auth_type}\n") - pghba.write(f"hostnossl {dbname} all 127.0.0.1/32 {auth_type}\n") - pghba.write(f"hostnossl {dbname} all ::1/128 {auth_type}\n") + pghba.write(f"local {dbname} {user} {auth_type}\n") + pghba.write(f"hostnossl {dbname} {user} 127.0.0.1/32 {auth_type}\n") + pghba.write(f"hostnossl {dbname} {user} ::1/128 {auth_type}\n") pghba.write(old_contents) - def ssl_access(self, dbname, auth_type): + def ssl_access(self, dbname, auth_type, user="all"): """Prepends a local SSL access rule to the HBA file""" with self.hba_path.open() as pghba: old_contents = pghba.read() with self.hba_path.open(mode="w") as pghba: - pghba.write(f"hostssl {dbname} all 127.0.0.1/32 {auth_type}\n") - pghba.write(f"hostssl {dbname} all ::1/128 {auth_type}\n") + pghba.write(f"hostssl {dbname} {user} 127.0.0.1/32 {auth_type}\n") + pghba.write(f"hostssl {dbname} {user} ::1/128 {auth_type}\n") pghba.write(old_contents) @property @@ -663,6 +841,8 @@ def __init__( ini.write(f"logfile = {self.log_path}\n") ini.write(f"auth_file = {self.auth_path}\n") ini.write("pidfile = \n") + # Uncomment for much more noise but, more detailed debugging + # ini.write("verbose = 3\n") if not USE_UNIX_SOCKETS: ini.write(f"unix_socket_dir = \n") @@ -807,8 +987,11 @@ def print_logs(self): assert not failed_valgrind async def cleanup(self): - await self.stop() - self.print_logs() + try: + cleanup_test_leftovers(self) + await self.stop() + finally: + self.print_logs() if 
self.port_lock: self.port_lock.release()