Skip to content

Commit

Permalink
feat: support JSON data type (#135)
Browse files Browse the repository at this point in the history
* feat: support JSON data type

* fix type

* bug fixes

* erase excess test override

* erase excess override

* fix errors

Co-authored-by: larkee <31196561+larkee@users.noreply.github.com>
Co-authored-by: skuruppu <skuruppu@google.com>
  • Loading branch information
3 people committed Dec 6, 2021
1 parent f02a2c0 commit 184a7d5
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 4 deletions.
4 changes: 4 additions & 0 deletions google/cloud/sqlalchemy_spanner/requirements.py
Expand Up @@ -17,6 +17,10 @@


class Requirements(SuiteRequirements): # pragma: no cover
@property
def json_type(self):
    """Report the JSON column type as available for the test suite."""
    # An "open" exclusion rule leaves the requirement enabled.
    return exclusions.open()

@property
def computed_columns(self):
    """Report computed columns as available for the test suite."""
    # An "open" exclusion rule leaves the requirement enabled.
    return exclusions.open()
Expand Down
62 changes: 62 additions & 0 deletions google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py
Expand Up @@ -34,9 +34,13 @@
GenericTypeCompiler,
IdentifierPreparer,
SQLCompiler,
OPERATORS,
RESERVED_WORDS,
)
from sqlalchemy.sql.default_comparator import operator_lookup
from sqlalchemy.sql.operators import json_getitem_op

from google.cloud.spanner_v1.data_types import JsonObject
from google.cloud import spanner_dbapi
from google.cloud.sqlalchemy_spanner._opentelemetry_tracing import trace_call

Expand All @@ -47,6 +51,10 @@ def reset_connection(dbapi_conn, connection_record):
dbapi_conn.connection.staleness = None


# Register the JSON item-access operator in the compiler's OPERATORS table,
# reusing the entry SQLAlchemy keeps in ``operator_lookup``, so that a single
# value of a JSON object can be fetched.
# NOTE(review): the registered entry is presumably not a plain string, which
# is how ``_generate_generic_binary`` tells the JSON case apart -- confirm.
OPERATORS[json_getitem_op] = operator_lookup["json_getitem_op"]


# Spanner-to-SQLAlchemy types map
_type_map = {
"BOOL": types.Boolean,
Expand All @@ -60,8 +68,10 @@ def reset_connection(dbapi_conn, connection_record):
"TIME": types.TIME,
"TIMESTAMP": types.TIMESTAMP,
"ARRAY": types.ARRAY,
"JSON": types.JSON,
}


_type_map_inv = {
types.Boolean: "BOOL",
types.BINARY: "BYTES(MAX)",
Expand Down Expand Up @@ -210,6 +220,53 @@ def visit_like_op_binary(self, binary, operator, **kw):
binary.right._compiler_dispatch(self, **kw),
)

def _generate_generic_binary(self, binary, opstring, eager_grouping=False, **kw):
    """Render a generic binary expression, with a special case for JSON.

    The method is overridden to process JSON data type cases: when
    ``opstring`` is not a plain string the expression is treated as a JSON
    item access and rendered as a ``JSON_VALUE()`` call.

    Args:
        binary: Binary expression exposing ``left`` and ``right`` operands.
        opstring: Operator text placed between the two operands. Any
            non-string value selects the JSON branch.
        eager_grouping (bool): Forwarded to both operands. Combined with a
            truthy ``_in_binary`` keyword it wraps the rendered text in
            parentheses.
        **kw: Extra compiler keyword arguments forwarded to the operands.

    Returns:
        str: The rendered SQL fragment.
    """
    nested = kw.get("_in_binary", False)

    # Operands compiled below must know they live inside a binary expression.
    kw["_in_binary"] = True

    if isinstance(opstring, str):
        left_sql = binary.left._compiler_dispatch(
            self, eager_grouping=eager_grouping, **kw
        )
        right_sql = binary.right._compiler_dispatch(
            self, eager_grouping=eager_grouping, **kw
        )
        text = left_sql + opstring + right_sql
        if nested and eager_grouping:
            text = "(%s)" % text
        return text

    # Got JSON data. Prefer the literal key carried by the right operand and
    # only compile the operand when no literal is available. The check is an
    # explicit ``is None`` so that falsy keys such as ``0`` are kept instead
    # of being replaced by a bind-parameter placeholder.
    json_key = getattr(binary.right, "value", None)
    if json_key is None:
        json_key = binary.right._compiler_dispatch(
            self, eager_grouping=eager_grouping, **kw
        )

    # The right operand is resolved before the left one, as in the original
    # implementation, so the order in which operands are compiled is kept.
    json_doc = binary.left._compiler_dispatch(
        self, eager_grouping=eager_grouping, **kw
    )

    # NOTE(review): the key is interpolated without escaping; a key holding a
    # double quote would produce a malformed path -- confirm whether callers
    # can pass such keys.
    return 'JSON_VALUE(%s, "$.%s")' % (json_doc, str(json_key))

def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
    """Compile a JSON path item access into a ``JSON_VALUE()`` call.

    Both operands are compiled with ``self.process`` (left first, then
    right) and substituted into ``JSON_VALUE(<left>, "$.<right>")``.
    """
    json_doc = self.process(binary.left, **kw)
    json_path = self.process(binary.right, **kw)
    return """JSON_VALUE(%s, "$.%s")""" % (json_doc, json_path)

def render_literal_value(self, value, type_):
"""Render the value of a bind parameter as a quoted literal.
Expand Down Expand Up @@ -404,6 +461,9 @@ def visit_NUMERIC(self, type_, **kw):
def visit_BIGINT(self, type_, **kw):
    """Return the type name rendered for BIGINT columns: ``INT64``."""
    return "INT64"

def visit_JSON(self, type_, **kw):
    """Return the type name rendered for JSON columns: ``JSON``."""
    return "JSON"


class SpannerDialect(DefaultDialect):
"""Cloud Spanner dialect.
Expand Down Expand Up @@ -434,6 +494,8 @@ class SpannerDialect(DefaultDialect):
statement_compiler = SpannerSQLCompiler
type_compiler = SpannerTypeCompiler
execution_ctx_cls = SpannerExecutionContext
_json_serializer = JsonObject
_json_deserializer = JsonObject

@classmethod
def dbapi(cls):
Expand Down
134 changes: 130 additions & 4 deletions test/test_suite.py
Expand Up @@ -20,6 +20,7 @@
import os
import pkg_resources
import pytest
import random
import unittest
from unittest import mock

Expand Down Expand Up @@ -61,7 +62,6 @@
)

from google.api_core.datetime_helpers import DatetimeWithNanoseconds

from google.cloud import spanner_dbapi

from sqlalchemy.testing.suite.test_cte import * # noqa: F401, F403
Expand Down Expand Up @@ -98,15 +98,17 @@
)
from sqlalchemy.testing.suite.test_results import RowFetchTest as _RowFetchTest
from sqlalchemy.testing.suite.test_types import ( # noqa: F401, F403
_DateFixture as _DateFixtureTest,
_LiteralRoundTripFixture,
_UnicodeFixture as _UnicodeFixtureTest,
BooleanTest as _BooleanTest,
DateTest as _DateTest,
_DateFixture as _DateFixtureTest,
DateTimeHistoricTest,
DateTimeCoercedToDateTimeTest as _DateTimeCoercedToDateTimeTest,
DateTimeMicrosecondsTest as _DateTimeMicrosecondsTest,
DateTimeTest as _DateTimeTest,
IntegerTest as _IntegerTest,
_LiteralRoundTripFixture,
JSONTest as _JSONTest,
NumericTest as _NumericTest,
StringTest as _StringTest,
TextTest as _TextTest,
Expand All @@ -115,7 +117,6 @@
TimestampMicrosecondsTest,
UnicodeVarcharTest as _UnicodeVarcharTest,
UnicodeTextTest as _UnicodeTextTest,
_UnicodeFixture as _UnicodeFixtureTest,
)
from test._helpers import get_db_url

Expand Down Expand Up @@ -1751,3 +1752,128 @@ def test_get_column_returns_computed(self):
is_true("computed" in compData)
is_true("sqltext" in compData["computed"])
eq_(self.normalize(compData["computed"]["sqltext"]), "normal+42")


@pytest.mark.skipif(
    bool(os.environ.get("SPANNER_EMULATOR_HOST")), reason="Skipped on emulator"
)
class JSONTest(_JSONTest):
    """JSON type tests built on top of the inherited ``_JSONTest`` suite.

    The whole class is skipped when ``SPANNER_EMULATOR_HOST`` is set. Several
    inherited cases are skipped with an explanatory reason, and a few are
    re-implemented so that every inserted row carries an explicit random id.
    """

    @pytest.mark.skip("Values without keys are not supported.")
    def test_single_element_round_trip(self, element):
        pass

    def _test_round_trip(self, data_element):
        # Insert one row holding ``data_element`` and read the column back.
        table = self.tables.data_table
        new_row = {
            "id": random.randint(1, 100000000),
            "name": "row1",
            "data": data_element,
        }
        config.db.execute(table.insert(), new_row)

        fetched = config.db.execute(select([table.c.data])).first()

        eq_(fetched, (data_element,))

    def test_unicode_round_trip(self):
        # note we include Unicode supplementary characters as well
        table = self.tables.data_table
        with config.db.connect() as conn:
            new_row = {
                "id": random.randint(1, 100000000),
                "name": "r1",
                "data": {
                    util.u("réve🐍 illé"): util.u("réve🐍 illé"),
                    "data": {"k1": util.u("drôl🐍e")},
                },
            }
            conn.execute(table.insert(), new_row)

            stored = conn.scalar(select([table.c.data]))
            expected = {
                util.u("réve🐍 illé"): util.u("réve🐍 illé"),
                "data": {"k1": util.u("drôl🐍e")},
            }
            eq_(stored, expected)

    @pytest.mark.skip("Parameterized types are not supported.")
    def test_eval_none_flag_orm(self):
        pass

    @pytest.mark.skip(
        "Spanner JSON_VALUE() always returns STRING,"
        "thus, this test case can't be executed."
    )
    def test_index_typed_comparison(self):
        pass

    @pytest.mark.skip(
        "Spanner JSON_VALUE() always returns STRING,"
        "thus, this test case can't be executed."
    )
    def test_path_typed_comparison(self):
        pass

    @pytest.mark.skip("Custom JSON de-/serializers are not supported.")
    def test_round_trip_custom_json(self):
        pass

    def _index_fixtures(fn):
        # Decorator helper: parametrize ``fn`` over (datatype, value) pairs.
        cases = (
            ("boolean", True),
            ("boolean", False),
            ("boolean", None),
            ("string", "some string"),
            ("string", None),
            ("integer", 15),
            ("integer", 1),
            ("integer", 0),
            ("integer", None),
            ("float", 28.5),
            ("float", None),
        )
        return testing.combinations(*cases, id_="sa")(fn)

    @_index_fixtures
    def test_index_typed_access(self, datatype, value):
        table = self.tables.data_table
        payload = {"key1": value}
        with config.db.connect() as conn:
            new_row = {
                "id": random.randint(1, 100000000),
                "name": "row1",
                "data": payload,
                "nulldata": payload,
            }
            conn.execute(table.insert(), new_row)

            # Build ``data["key1"].as_<datatype>()`` for the requested type.
            accessor = getattr(table.c.data["key1"], "as_%s" % datatype)
            expr = accessor()

            roundtrip = conn.scalar(select([expr]))
            # Normalise "true"/"false"/None so the string comparison below
            # lines up with Python's ``str(True)``/``str(False)``/``str(None)``.
            if roundtrip in ("true", "false", None):
                roundtrip = str(roundtrip).capitalize()

            eq_(str(roundtrip), str(value))

    @pytest.mark.skip(
        "Spanner doesn't support type casts inside JSON_VALUE() function."
    )
    def test_round_trip_json_null_as_json_null(self):
        pass

    @pytest.mark.skip(
        "Spanner doesn't support type casts inside JSON_VALUE() function."
    )
    def test_round_trip_none_as_json_null(self):
        pass

    @pytest.mark.skip(
        "Spanner doesn't support type casts inside JSON_VALUE() function."
    )
    def test_round_trip_none_as_sql_null(self):
        pass

0 comments on commit 184a7d5

Please sign in to comment.