Skip to content

Commit

Permalink
Merge pull request #16 from jwills/jwills_gpt_write_tests
Browse files Browse the repository at this point in the history
GPT-4 boilerplate tests and workflows and what not
  • Loading branch information
jwills committed May 9, 2023
2 parents 954f6e0 + 396663e commit fc9f7ec
Show file tree
Hide file tree
Showing 9 changed files with 391 additions and 3 deletions.
36 changes: 36 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
name: Buena Vista Unit Tests

on:
pull_request:
branches: [ main ]

jobs:
build:

runs-on: ubuntu-latest

strategy:
matrix:
python-version: ['3.8', '3.9', '3.10']

steps:
- name: Check out the repository
uses: actions/checkout@v3
with:
persist-credentials: false

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r dev-requirements.txt
pip install .
- name: Run tests
run: |
pytest tests/
1 change: 1 addition & 0 deletions buenavista/bv_dialects.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from sqlglot.dialects import DuckDB, Postgres, Trino
from sqlglot.tokens import TokenType


# Additional expressions I need
class ToISO8601(exp.Func):
pass
Expand Down
1 change: 0 additions & 1 deletion buenavista/examples/duckdb_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def rewrite(self, sql: str) -> str:


if __name__ == "__main__":

if len(sys.argv) < 2:
print("Using in-memory DuckDB database")
db = duckdb.connect()
Expand Down
4 changes: 2 additions & 2 deletions buenavista/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,10 @@ def transaction_status(self):
return TransactionStatus.IDLE

def execute_sql(self, sql: str, params=None) -> QueryResult:
print("Input SQL: " + sql)
logger.info("Input SQL: " + sql)
if self.rewriter:
sql = self.rewriter.rewrite(sql)
print("Rewritten SQL: " + sql)
logger.info("Rewritten SQL: " + sql)
return self.session.execute_sql(sql, params)

def describe_portal(self, name: str) -> QueryResult:
Expand Down
61 changes: 61 additions & 0 deletions tests/postgres/test_bv_buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import io
import pytest
from buenavista.postgres import BVBuffer


@pytest.fixture
def bv_buffer():
return BVBuffer()


def test_bv_buffer_init(bv_buffer):
assert isinstance(bv_buffer, BVBuffer)
assert isinstance(bv_buffer.stream, io.BytesIO)


def test_bv_buffer_read_write_bytes(bv_buffer):
data = b"test_data"
bv_buffer.write_bytes(data)
bv_buffer.stream.seek(0)
assert bv_buffer.read_bytes(len(data)) == data


def test_bv_buffer_read_write_byte(bv_buffer):
data = b"T"
bv_buffer.write_byte(data)
bv_buffer.stream.seek(0)
assert bv_buffer.read_byte() == data


def test_bv_buffer_read_write_int16(bv_buffer):
data = 12345
bv_buffer.write_int16(data)
bv_buffer.stream.seek(0)
assert bv_buffer.read_int16() == data


def test_bv_buffer_read_write_uint32(bv_buffer):
data = 12345678
bv_buffer.write_int32(data)
bv_buffer.stream.seek(0)
assert bv_buffer.read_uint32() == data


def test_bv_buffer_read_write_int32(bv_buffer):
data = -12345678
bv_buffer.write_int32(data)
bv_buffer.stream.seek(0)
assert bv_buffer.read_int32() == data


def test_bv_buffer_write_string(bv_buffer):
data = "test_string"
bv_buffer.write_string(data)
bv_buffer.stream.seek(0)
assert bv_buffer.stream.read(len(data) + 1) == data.encode() + b"\x00"


def test_bv_buffer_get_value(bv_buffer):
data = b"get_value_test"
bv_buffer.write_bytes(data)
assert bv_buffer.get_value() == data
80 changes: 80 additions & 0 deletions tests/postgres/test_bv_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import pytest
from typing import Dict
from unittest.mock import MagicMock

from buenavista.core import Session
from buenavista.postgres import BVContext, TransactionStatus


@pytest.fixture
def mock_session():
session = MagicMock(spec=Session)
session.in_transaction.return_value = False
return session


@pytest.fixture
def bv_context(mock_session):
return BVContext(session=mock_session, rewriter=None, params={})


def test_bv_context_init(bv_context, mock_session):
assert bv_context.session == mock_session
assert bv_context.rewriter is None
assert bv_context.params == {}
assert isinstance(bv_context.process_id, int)
assert isinstance(bv_context.secret_key, int)
assert bv_context.stmts == {}
assert bv_context.portals == {}
assert bv_context.result_cache == {}
assert bv_context.has_error is False


def test_bv_context_mark_error(bv_context):
bv_context.mark_error()
assert bv_context.has_error is True


def test_bv_context_transaction_status(bv_context, mock_session):
mock_session.in_transaction.return_value = False
assert bv_context.transaction_status() == TransactionStatus.IDLE

bv_context.mark_error()
mock_session.in_transaction.return_value = True
assert bv_context.transaction_status() == TransactionStatus.IN_FAILED_TRANSACTION

bv_context.has_error = False
assert bv_context.transaction_status() == TransactionStatus.IN_TRANSACTION


def test_bv_context_add_close_statement(bv_context):
name = "stmt1"
sql = "SELECT * FROM test;"
bv_context.add_statement(name, sql)
assert bv_context.stmts[name] == sql

bv_context.close_statement(name)
assert name not in bv_context.stmts


def test_bv_context_add_close_portal(bv_context):
portal_name = "portal1"
stmt_name = "stmt1"
params = {"param1": "value1"}

bv_context.add_portal(portal_name, stmt_name, params)
assert bv_context.portals[portal_name] == (stmt_name, params)

bv_context.close_portal(portal_name)
assert portal_name not in bv_context.portals


def test_bv_context_flush(bv_context):
# This method is a no-op, but including a test for completeness
bv_context.flush()


def test_bv_context_sync(bv_context):
bv_context.mark_error()
bv_context.sync()
assert bv_context.has_error is False
54 changes: 54 additions & 0 deletions tests/postgres/test_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import pytest
from unittest.mock import MagicMock, patch

from buenavista.core import BVType, Connection, SimpleQueryResult
from buenavista.postgres import (
BuenaVistaHandler,
BVBuffer,
BVContext,
TransactionStatus,
)
from buenavista.rewrite import Rewriter

# Add any necessary setup or helper methods here, e.g., creating a test server and client


@pytest.fixture
def mock_handler():
request, client_address = MagicMock(), MagicMock()
server = MagicMock()
server.conn = MagicMock(spec=Connection)
server.rewriter = MagicMock(spec=Rewriter)
server.ctxts = {}

handler = BuenaVistaHandler(request, client_address, server)
handler.r = MagicMock(spec=BVBuffer)
handler.wfile = MagicMock()

return handler


def test_handle_startup(mock_handler):
mock_handler.r.read_uint32.side_effect = [8, 196608]
mock_handler.r.read_bytes.return_value = b"user\x00test\x00database\x00testdb\x00"
ctx = mock_handler.handle_startup(mock_handler.server.conn)
assert isinstance(ctx, BVContext)
assert ctx.session is not None
assert ctx.params == {"user": "test", "database": "testdb"}


def test_handle_query(mock_handler):
ctx = MagicMock(spec=BVContext)
ctx.execute_sql.return_value = SimpleQueryResult("col1", 1, BVType.INTEGER)
ctx.transaction_status.return_value = TransactionStatus.IDLE
mock_handler.handle_query(ctx, b"SELECT 1;\x00")
ctx.execute_sql.assert_called_once_with("SELECT 1;")


def test_handle_parse(mock_handler):
ctx = MagicMock(spec=BVContext)
mock_handler.handle_parse(ctx, b"stmt1\x00SELECT 1;\x00")
ctx.add_statement.assert_called_once_with("stmt1", "SELECT 1;")


# Add more test cases for other methods in the BuenaVistaHandler class
113 changes: 113 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import pytest
import uuid
from typing import Iterator, List, Tuple

from buenavista.core import (
BVType,
QueryResult,
Session,
Connection,
Extension,
SimpleQueryResult,
)


# ----------------------- QueryResult -----------------------
class DummyQueryResult(QueryResult):
def has_results(self) -> bool:
return True

def column_count(self) -> int:
return 1

def column(self, index: int) -> Tuple[str, BVType]:
return ("dummy", BVType.TEXT)

def rows(self) -> Iterator[List]:
return iter([["dummy_row"]])

def status(self) -> str:
return "Dummy status"


@pytest.fixture
def dummy_query_result():
return DummyQueryResult()


def test_query_result_has_results(dummy_query_result):
assert dummy_query_result.has_results() is True


def test_query_result_column_count(dummy_query_result):
assert dummy_query_result.column_count() == 1


def test_query_result_column(dummy_query_result):
assert dummy_query_result.column(0) == ("dummy", BVType.TEXT)


def test_query_result_rows(dummy_query_result):
assert list(dummy_query_result.rows()) == [["dummy_row"]]


def test_query_result_status(dummy_query_result):
assert dummy_query_result.status() == "Dummy status"


# ----------------------- Session -----------------------
def test_session_init():
session = Session()
assert isinstance(session, Session)
assert isinstance(session.id, uuid.UUID)


# ----------------------- Connection -----------------------
def test_connection_init():
connection = Connection()
assert isinstance(connection, Connection)
assert connection._sessions == {}


# ----------------------- Extension -----------------------
def test_extension_check_json():
payload = '{"key": "value"}'
result = Extension.check_json(payload)
assert result == {"key": "value"}

payload = "not json;"
result = Extension.check_json(payload)
assert result is None


# ----------------------- SimpleQueryResult -----------------------
def test_simple_query_result_init():
sqr = SimpleQueryResult("test", 42, BVType.INTEGER)
assert sqr.name == "test"
assert sqr.value == "42"
assert sqr.type == BVType.INTEGER


def test_simple_query_result_has_results():
sqr = SimpleQueryResult("test", 42, BVType.INTEGER)
assert sqr.has_results() is True


def test_simple_query_result_column_count():
sqr = SimpleQueryResult("test", 42, BVType.INTEGER)
assert sqr.column_count() == 1


def test_simple_query_result_column():
sqr = SimpleQueryResult("test", 42, BVType.INTEGER)
assert sqr.column(0) == ("test", BVType.INTEGER)


def test_simple_query_result_rows():
sqr = SimpleQueryResult("test", 42, BVType.INTEGER)
assert list(sqr.rows()) == [["42"]]


def test_simple_query_result_status():
sqr = SimpleQueryResult("test", 42, BVType.INTEGER)
assert sqr.status() == ""

0 comments on commit fc9f7ec

Please sign in to comment.