Skip to content

Commit

Permalink
Merge pull request #17 from jwills/jwills_gpt4_write_tedious_postgres…
Browse files Browse the repository at this point in the history
…_auth_stuff

Okay it turns out I had to write most of the tedious auth stuff but at least it works now
  • Loading branch information
jwills committed May 9, 2023
2 parents fc9f7ec + 0d3dc19 commit eab23b7
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 12 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ that illustrates the idea locally:

```sh
pip3 install buenavista
python3 -m buenavista.backends.duckdb <optional_duckdb_file>
python3 -m buenavista.examples.duckdb_postgres <optional_duckdb_file>
```

in order to start a Postgres server on `localhost:5433` backed by the DuckDB database file that you passed in at the command line
Expand Down
4 changes: 3 additions & 1 deletion buenavista/examples/duckdb_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def rewrite(self, sql: str) -> str:

address = (bv_host, bv_port)

server = postgres.BuenaVistaServer(address, DuckDBConnection(db), rewriter=rewriter)
server = postgres.BuenaVistaServer(
address, DuckDBConnection(db), rewriter=rewriter, auth=None
)
ip, port = server.server_address
print(f"Listening on {ip}:{port}")

Expand Down
64 changes: 54 additions & 10 deletions buenavista/postgres.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hashlib
import io
import json
import logging
Expand Down Expand Up @@ -53,6 +54,7 @@ class ClientCommand:
FLUSH = b"H"
QUERY = b"Q"
PARSE = b"P"
PASSWORD_MESSAGE = b"p"
SYNC = b"S"
TERMINATE = b"X"

Expand Down Expand Up @@ -142,6 +144,14 @@ def __init__(
self.portals = {}
self.result_cache = {}
self.has_error = False
self.authenticated = False
self.salt = None

def get_hashed_password(self, auth: dict) -> str:
user = self.params["user"]
password = auth[user]
first = hashlib.md5(password.encode("utf-8") + user.encode("utf-8")).hexdigest()
return "md5" + hashlib.md5(first.encode("utf-8") + self.salt).hexdigest()

def mark_error(self):
self.has_error = True
Expand Down Expand Up @@ -222,7 +232,12 @@ def handle(self):
else:
payload = None

if type_code == ClientCommand.QUERY:
if not ctx.authenticated:
if type_code == ClientCommand.PASSWORD_MESSAGE:
self.handle_md5_password(ctx, payload)
else:
raise Exception("Not authenticated")
elif type_code == ClientCommand.QUERY:
self.handle_query(ctx, payload)
elif type_code == ClientCommand.PARSE:
self.handle_parse(ctx, payload)
Expand Down Expand Up @@ -264,10 +279,7 @@ def handle_startup(self, conn: Connection) -> BVContext:
params = dict(zip(msg[::2], msg[1::2]))
logger.info("Client connection params: %s", params)
ctx = BVContext(conn.create_session(), self.server.rewriter, params)
self.send_authentication_ok()
self.send_parameter_status(conn.parameters())
self.send_backend_key_data(ctx)
self.send_ready_for_query(ctx)
self.send_auth_request(ctx)
return ctx
elif code == 80877102: ## Cancel request
process_id, secret_key = self.r.read_uint32(), self.r.read_uint32()
Expand All @@ -280,6 +292,41 @@ def handle_startup(self, conn: Connection) -> BVContext:
else:
raise Exception(f"Unsupported startup message: {code}")

def send_auth_request(self, ctx: BVContext):
if self.server.auth is None:
self.send_authentication_ok()
self.handle_post_auth(ctx)
else:
self.send_authentication_md5(ctx)

def send_authentication_ok(self):
self.wfile.write(
struct.pack("!cii", ServerResponse.AUTHENTICATION_REQUEST, 8, 0)
)

def send_authentication_md5(self, ctx: BVContext):
ctx.salt = os.urandom(4)
self.wfile.write(
struct.pack("!cii", ServerResponse.AUTHENTICATION_REQUEST, 12, 5)
)
self.wfile.write(ctx.salt)

def handle_md5_password(self, ctx: BVContext, payload: bytes):
client_side = payload.decode("utf-8").rstrip("\x00")
server_side = ctx.get_hashed_password(self.server.auth)
if client_side == server_side:
self.send_authentication_ok()
self.handle_post_auth(ctx)
else:
self.send_error("Invalid password")

def handle_post_auth(self, ctx: BVContext):
self.send_parameter_status(self.server.conn.parameters())
self.send_backend_key_data(ctx)
self.send_ready_for_query(ctx)
ctx.authenticated = True
return

def handle_query(self, ctx: BVContext, payload: bytes):
logger.debug("Handle query")
decoded = payload.decode("utf-8").rstrip("\x00")
Expand Down Expand Up @@ -475,11 +522,6 @@ def send_error(self, exception, ctx: Optional[BVContext] = None):
def send_notice(self):
self.wfile.write(ServerResponse.NOTICE_RESPONSE)

def send_authentication_ok(self):
self.wfile.write(
struct.pack("!cii", ServerResponse.AUTHENTICATION_REQUEST, 8, 0)
)

def send_backend_key_data(self, ctx):
self.wfile.write(
struct.pack(
Expand Down Expand Up @@ -536,12 +578,14 @@ def __init__(
*,
rewriter: Optional[Rewriter] = None,
extensions: List[Extension] = [],
auth: Optional[Dict[str, str]] = None,
):
super().__init__(server_address, BuenaVistaHandler)
self.conn = conn
self.rewriter = rewriter
self.extensions = {e.type(): e for e in extensions}
self.ctxts = {}
self.auth = auth

def verify_request(self, request, client_address) -> bool:
"""Ensure all requests come from localhost until auth is in place"""
Expand Down

0 comments on commit eab23b7

Please sign in to comment.