Merge pull request #23 from jwills/jwills_powerbi
Add support for binary encoding of result rows for clients that request it (such as PowerBI)
jwills committed Oct 12, 2023
2 parents 5133d53 + 10dae39 commit 57c7359
Showing 8 changed files with 116 additions and 38 deletions.
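
Under the extended query protocol, a client opts into binary results by sending result-format code 1 in its Bind message; 0 (or no code at all) keeps the text encoding. As a rough, hedged illustration of how the new path gets exercised from Python, the sketch below uses psycopg (already listed in dev-requirements.txt); the DSN, port, and query are placeholders and assume a Buena Vista server is listening locally.

import psycopg

# Hypothetical smoke test; connection details are placeholders, not part of this change.
DSN = "host=localhost port=5433 user=test dbname=memory"

with psycopg.connect(DSN) as conn:
    # binary=True makes psycopg request result-format code 1 in its Bind messages,
    # which routes rows through the new binary converters instead of the text path.
    with conn.cursor(binary=True) as cur:
        cur.execute("SELECT 42::BIGINT AS answer")
        print(cur.fetchone())
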
1 change: 1 addition & 0 deletions buenavista/backends/duckdb.py
@@ -79,6 +79,7 @@ class DuckDBQueryResult(QueryResult):
def __init__(
self, rbr: Optional[pa.RecordBatchReader] = None, status: Optional[str] = None
):
+        super().__init__()
if rbr:
self.rbr = rbr
self.bvtypes = [to_bvtype(s.type) for s in rbr.schema]
1 change: 1 addition & 0 deletions buenavista/backends/postgres.py
@@ -34,6 +34,7 @@ def __init__(
rows: List[List[Optional[Any]]],
status: Optional[str] = None,
):
+        super().__init__()
self.fields = fields
self._rows = rows
self._status = status
4 changes: 4 additions & 0 deletions buenavista/core.py
@@ -28,6 +28,9 @@ class BVType(enum.Enum):
class QueryResult:
"""The BV representation of a result of a query."""

+    def __init__(self):
+        self.result_format = None

def has_results(self) -> bool:
raise NotImplementedError

@@ -113,6 +116,7 @@ def apply(self, params: dict, session: Session) -> QueryResult:

class SimpleQueryResult(QueryResult):
def __init__(self, name: str, value: Any, type: BVType):
+        super().__init__()
self.name = name
self.value = str(value)
self.type = type
130 changes: 100 additions & 30 deletions buenavista/postgres.py
@@ -1,3 +1,4 @@
+import datetime
import hashlib
import io
import json
@@ -59,27 +60,65 @@ class ClientCommand:
TERMINATE = b"X"


+def _time_to_microseconds(t):
+    # Convert hours, minutes, seconds, and microseconds to microseconds
+    hours_to_microseconds = t.hour * 60 * 60 * 1e6
+    minutes_to_microseconds = t.minute * 60 * 1e6
+    seconds_to_microseconds = t.second * 1e6
+    microseconds = t.microsecond
+    total_microseconds = (
+        hours_to_microseconds
+        + minutes_to_microseconds
+        + seconds_to_microseconds
+        + microseconds
+    )
+    return int(total_microseconds)


+def _micros_since_2000(dt):
+    micros = (
+        dt - datetime.datetime(2000, 1, 1, tzinfo=datetime.timezone.utc)
+    ).total_seconds() * 1000000
+    return int(micros)


PG_UNKNOWN = (705, str)
BVTYPE_TO_PGTYPE = {
BVType.NULL: (-1, lambda v: None),
-    BVType.ARRAY: (2277, lambda v: "{" + ",".join(v) + "}"),
-    BVType.BIGINT: (20, str),
-    BVType.BOOL: (16, lambda v: "true" if v else "false"),
-    BVType.BYTES: (17, lambda v: "\\x" + v.hex()),
-    BVType.DATE: (1082, lambda v: v.isoformat()),
+    BVType.ARRAY: (2277, lambda v: "{" + ",".join(v) + "}", None),
+    BVType.BIGINT: (20, str, lambda r: int.to_bytes(r, 8, "big")),
+    BVType.BOOL: (
+        16,
+        lambda v: "true" if v else "false",
+        lambda r: b"\x01" if r else b"\x00",
+    ),
+    BVType.BYTES: (17, lambda v: "\\x" + v.hex(), lambda r: r),
+    BVType.DATE: (
+        1082,
+        lambda v: v.isoformat(),
+        lambda r: int.to_bytes((r.toordinal() - 730120), 4, "big"),
+    ),
BVType.DECIMAL: (1700, str),
-    BVType.FLOAT: (701, str),
-    BVType.INTEGER: (23, str),
+    BVType.FLOAT: (701, str, lambda r: struct.pack("!d", r)),
+    BVType.INTEGER: (23, str, lambda r: int.to_bytes(r, 4, "big")),
BVType.INTEGERARRAY: (1007, lambda v: "{" + ",".join(v) + "}"),
BVType.INTERVAL: (
1186,
lambda v: f"{v.days} days {v.seconds} seconds {v.microseconds} microseconds",
),
BVType.JSON: (114, lambda v: json.dumps(v)),
BVType.STRINGARRAY: (1009, lambda v: "{" + ",".join(v) + "}"),
-    BVType.TEXT: (25, str),
-    BVType.TIME: (1083, lambda v: v.isoformat()),
-    BVType.TIMESTAMP: (1114, lambda v: v.isoformat().replace("T", " ")),
+    BVType.TEXT: (25, str, lambda r: r.encode("utf-8")),
+    BVType.TIME: (
+        1083,
+        lambda v: v.isoformat(),
+        lambda r: int.to_bytes(_time_to_microseconds(r), 8, "big"),
+    ),
+    BVType.TIMESTAMP: (
+        1114,
+        lambda v: v.isoformat().replace("T", " "),
+        lambda r: int.to_bytes(_micros_since_2000(r), 8, "big"),
+    ),
}
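
The binary converters above follow PostgreSQL's wire conventions: DATE is a 4-byte count of days since 2000-01-01 (730120 is that date's proleptic ordinal), TIMESTAMP is an 8-byte count of microseconds since 2000-01-01 00:00:00 UTC, and TIME is microseconds since midnight. A small standalone sanity check of that arithmetic (illustrative only, not part of the diff):

import datetime

# The PostgreSQL date epoch: 730120 == datetime.date(2000, 1, 1).toordinal().
assert datetime.date(2000, 1, 1).toordinal() - 730120 == 0
assert datetime.date(2000, 1, 2).toordinal() - 730120 == 1  # one day past the epoch

# TIME as microseconds since midnight: 01:02:03 is 3723 seconds, i.e. 3_723_000_000 microseconds.
t = datetime.time(1, 2, 3)
assert (t.hour * 3600 + t.minute * 60 + t.second) * 1_000_000 + t.microsecond == 3_723_000_000

# TIMESTAMP as microseconds since 2000-01-01 00:00:00 UTC.
epoch = datetime.datetime(2000, 1, 1, tzinfo=datetime.timezone.utc)
dt = datetime.datetime(2000, 1, 1, 0, 0, 1, tzinfo=datetime.timezone.utc)
assert int((dt - epoch).total_seconds() * 1_000_000) == 1_000_000
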


@@ -164,22 +203,30 @@ def transaction_status(self):
return TransactionStatus.IN_TRANSACTION
return TransactionStatus.IDLE

-    def execute_sql(self, sql: str, params=None) -> QueryResult:
+    def execute_sql(self, sql: str, params=None, result_fmt=None) -> QueryResult:
logger.info("Input SQL: " + sql)
if self.rewriter:
sql = self.rewriter.rewrite(sql)
logger.info("Rewritten SQL: " + sql)
-        return self.session.execute_sql(sql, params)
+        qr = self.session.execute_sql(sql, params)
+        if qr.has_results():
+            if result_fmt and len(result_fmt) != qr.column_count():
+                qr.result_format = [result_fmt[0]] * qr.column_count()
+            else:
+                qr.result_format = result_fmt
+        return qr

def describe_portal(self, name: str) -> QueryResult:
-        stmt, params = self.portals[name]
-        sql = self.stmts[stmt]
-        query_result = self.execute_sql(sql=sql, params=params)
+        stmt, params, result_fmt = self.portals[name]
+        sql, param_oids = self.stmts[stmt]
+        # todo: parse params? LIMIT 0?
+        query_result = self.execute_sql(sql=sql, params=params, result_fmt=result_fmt)
self.result_cache[name] = query_result
return query_result

def describe_statement(self, name: str) -> QueryResult:
-        sql = self.stmts[name]
+        sql, param_oids = self.stmts[name]
+        # TODO: create default params from param_oids
return self.execute_sql(sql)

def execute_portal(self, name: str) -> QueryResult:
@@ -188,18 +235,22 @@ def execute_portal(self, name: str) -> QueryResult:
del self.result_cache[name]
return query_result
else:
-            stmt, params = self.portals[name]
-            sql = self.stmts[stmt]
-            return self.execute_sql(sql=sql, params=params)
+            stmt, params, result_fmt = self.portals[name]
+            sql, param_oids = self.stmts[stmt]
+            # parse the params?
+            qr = self.execute_sql(sql=sql, params=params, result_fmt=result_fmt)
+            return qr

-    def add_statement(self, name: str, sql: str):
-        self.stmts[name] = sql
+    def add_statement(self, name: str, sql: str, param_oids: List[int]):
+        self.stmts[name] = (sql, param_oids)

def close_statement(self, name: str):
del self.stmts[name]

-    def add_portal(self, name: str, stmt: str, params: Dict[str, str]):
-        self.portals[name] = (stmt, params)
+    def add_portal(
+        self, name: str, stmt: str, params: Dict[str, str], result_formats: List[int]
+    ):
+        self.portals[name] = (stmt, params, result_formats)

def close_portal(self, name: str):
del self.portals[name]
@@ -346,7 +397,9 @@ def handle_query(self, ctx: BVContext, payload: bytes):
self.send_ready_for_query(ctx)
return

-        if query_result.has_results():
+        if not query_result:
+            raise Exception("No query result for query: " + decoded)
+        elif query_result.has_results():
self.send_row_description(query_result)
row_count = self.send_data_rows(query_result)
self.send_command_complete("SELECT %d\x00" % row_count)
@@ -363,7 +416,12 @@ def handle_parse(self, ctx: BVContext, payload: bytes):
query_idx = ba.index(NULL_BYTE, stmt_idx + 1)
sql = ba[stmt_idx + 1 : query_idx].decode("utf-8")
logger.debug("Parsed statement: %s", sql)
-        ctx.add_statement(stmt, sql)
+        buf = BVBuffer(io.BytesIO(ba[query_idx + 1 :]))
+        num_params = buf.read_int16()
+        param_oids = []
+        for i in range(num_params):
+            param_oids.append(buf.read_int32())
+        ctx.add_statement(stmt, sql, param_oids)
self.send_parse_complete()
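
For reference, the Parse payload read above is laid out as: statement name (NUL-terminated), query string (NUL-terminated), an int16 count of parameter-type OIDs, then one int32 OID per parameter, all in network byte order. A hedged sketch of building such a payload by hand; the helper name is illustrative, and the result reproduces the bytes used in the updated unit test further down.

import struct

def build_parse_payload(stmt: str, sql: str, param_oids=()) -> bytes:
    # Statement name and SQL are NUL-terminated, followed by an int16 OID count
    # and one int32 per parameter-type OID.
    body = stmt.encode("utf-8") + b"\x00" + sql.encode("utf-8") + b"\x00"
    body += struct.pack("!h", len(param_oids))
    for oid in param_oids:
        body += struct.pack("!i", oid)
    return body

assert build_parse_payload("stmt1", "SELECT 1;") == b"stmt1\x00SELECT 1;\x00\x00\x00"
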

def handle_bind(self, ctx: BVContext, payload: bytes):
@@ -401,7 +459,12 @@ def handle_bind(self, ctx: BVContext, payload: bytes):
# ints but I can live with it for now
params.append(int.from_bytes(v, "big"))
logger.debug("Bind params: %s", params)
-        ctx.add_portal(portal, stmt, params)
+        # now expected result formats
+        num_result_formats = buf.read_int16()
+        result_formats = []
+        for i in range(num_result_formats):
+            result_formats.append(buf.read_int16())
+        ctx.add_portal(portal, stmt, params, result_formats)
self.send_bind_complete()
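
In the Bind message, the result-format codes follow the parameter values: an int16 count and then one int16 code per entry, where 0 means text and 1 means binary. Per the protocol, zero codes means all-text and a single code applies to every result column, which is the broadcast case execute_sql above handles. A small sketch of the tail bytes a binary-preferring client such as PowerBI might send (illustrative values):

import struct

# One code that applies to every result column: request binary (1) for everything.
all_binary_tail = struct.pack("!hh", 1, 1)
assert all_binary_tail == b"\x00\x01\x00\x01"

# Explicit per-column codes: first column text (0), second column binary (1).
per_column_tail = struct.pack("!hhh", 2, 0, 1)
assert per_column_tail == b"\x00\x02\x00\x00\x00\x01"
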

def handle_describe(self, ctx: BVContext, payload: bytes):
@@ -468,7 +531,8 @@ def send_row_description(self, query_result: QueryResult):
name, bvtype = query_result.column(i)
oid = BVTYPE_TO_PGTYPE.get(bvtype, PG_UNKNOWN)[0]
buf.write_string(name)
-            buf.write_bytes(struct.pack("!ihihih", 0, 0, oid, 0, -1, 0))
+            fmt = query_result.result_format[i] if query_result.result_format else 0
+            buf.write_bytes(struct.pack("!ihihih", 0, 0, oid, 0, -1, fmt))
out = buf.get_value()
sig = struct.pack(
"!cih",
@@ -483,14 +547,20 @@ def send_data_rows(self, query_result: QueryResult, limit: int = 0) -> int:
converters = []
for i in range(query_result.column_count()):
bvtype = query_result.column(i)[1]
-            converters.append(BVTYPE_TO_PGTYPE.get(bvtype, PG_UNKNOWN)[1])
+            pgtype = BVTYPE_TO_PGTYPE.get(bvtype, PG_UNKNOWN)
+            if not query_result.result_format or query_result.result_format[i] == 0:
+                txt_fn = pgtype[1]
+                c = lambda r: txt_fn(r).encode("utf-8")
+            else:
+                c = pgtype[2]
+            converters.append(c)
for row in query_result.rows():
buf = BVBuffer()
-            for i, r in enumerate(row):
+            for j, r in enumerate(row):
if r is None:
buf.write_int32(-1)
else:
-                    v = converters[i](r).encode("utf-8")
+                    v = converters[j](r)
buf.write_int32(len(v))
buf.write_bytes(v)
out = buf.get_value()
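
Each DataRow field is written as an int32 length followed by the value bytes, or a length of -1 with no payload for NULL. With this change, text-format columns still carry the UTF-8 rendering from the text converter, while binary-format columns carry the raw bytes from the new third converter. A short illustration for a BIGINT value, mirroring the converters in the table above (illustrative, not part of the diff):

import struct

value = 1234

# Text format: length prefix plus the UTF-8 string "1234".
text_field = struct.pack("!i", 4) + str(value).encode("utf-8")
assert text_field == b"\x00\x00\x00\x04" + b"1234"

# Binary format: length prefix plus an 8-byte big-endian integer.
binary_field = struct.pack("!i", 8) + int.to_bytes(value, 8, "big")
assert binary_field[4:] == b"\x00\x00\x00\x00\x00\x00\x04\xd2"

# NULL is a -1 length with no payload, matching the branch above.
null_field = struct.pack("!i", -1)
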
1 change: 1 addition & 0 deletions dev-requirements.txt
@@ -3,5 +3,6 @@ fastapi[all]
psycopg
psycopg-pool
pyarrow
+pydantic>=1.2.0,<2.0.0
pytest
sqlglot
2 changes: 1 addition & 1 deletion setup.py
@@ -28,7 +28,7 @@
include_package_data=True,
install_requires=[
"fastapi",
-        "pydantic",
+        "pydantic>=1.2.0,<2.0.0",
"sqlglot",
],
extras_require={
11 changes: 6 additions & 5 deletions tests/unit/postgres/test_bv_context.py
@@ -50,8 +50,9 @@ def test_bv_context_transaction_status(bv_context, mock_session):
def test_bv_context_add_close_statement(bv_context):
name = "stmt1"
sql = "SELECT * FROM test;"
-    bv_context.add_statement(name, sql)
-    assert bv_context.stmts[name] == sql
+    param_oids = []
+    bv_context.add_statement(name, sql, param_oids)
+    assert bv_context.stmts[name] == (sql, param_oids)

bv_context.close_statement(name)
assert name not in bv_context.stmts
@@ -61,9 +62,9 @@ def test_bv_context_add_close_portal(bv_context):
portal_name = "portal1"
stmt_name = "stmt1"
params = {"param1": "value1"}

-    bv_context.add_portal(portal_name, stmt_name, params)
-    assert bv_context.portals[portal_name] == (stmt_name, params)
+    result_fmt = [0]
+    bv_context.add_portal(portal_name, stmt_name, params, result_fmt)
+    assert bv_context.portals[portal_name] == (stmt_name, params, result_fmt)

bv_context.close_portal(portal_name)
assert portal_name not in bv_context.portals
4 changes: 2 additions & 2 deletions tests/unit/postgres/test_handler.py
@@ -47,8 +47,8 @@ def test_handle_query(mock_handler):

def test_handle_parse(mock_handler):
ctx = MagicMock(spec=BVContext)
-    mock_handler.handle_parse(ctx, b"stmt1\x00SELECT 1;\x00")
-    ctx.add_statement.assert_called_once_with("stmt1", "SELECT 1;")
+    mock_handler.handle_parse(ctx, b"stmt1\x00SELECT 1;\x00\x00\x00")
+    ctx.add_statement.assert_called_once_with("stmt1", "SELECT 1;", [])


# Add more test cases for other methods in the BuenaVistaHandler class
