Skip to content

Commit

Permalink
feat(db_api): support JSON data type (#627)
Browse files Browse the repository at this point in the history
Co-authored-by: larkee <31196561+larkee@users.noreply.github.com>
  • Loading branch information
Ilya Gurov and larkee committed Nov 22, 2021
1 parent d769ff8 commit d760c2c
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 25 deletions.
1 change: 1 addition & 0 deletions google/cloud/spanner_dbapi/cursor.py
Expand Up @@ -223,6 +223,7 @@ def execute(self, sql, args=None):
ResultsChecksum(),
classification == parse_utils.STMT_INSERT,
)

(self._result_set, self._checksum,) = self.connection.run_statement(
statement
)
Expand Down
8 changes: 3 additions & 5 deletions google/cloud/spanner_v1/_helpers.py
Expand Up @@ -17,7 +17,6 @@
import datetime
import decimal
import math
import json

from google.protobuf.struct_pb2 import ListValue
from google.protobuf.struct_pb2 import Value
Expand Down Expand Up @@ -166,9 +165,8 @@ def _make_value_pb(value):
_assert_numeric_precision_and_scale(value)
return Value(string_value=str(value))
if isinstance(value, JsonObject):
return Value(
string_value=json.dumps(value, sort_keys=True, separators=(",", ":"),)
)
return Value(string_value=value.serialize())

raise ValueError("Unknown type: %s" % (value,))


Expand Down Expand Up @@ -243,7 +241,7 @@ def _parse_value_pb(value_pb, field_type):
elif type_code == TypeCode.NUMERIC:
return decimal.Decimal(value_pb.string_value)
elif type_code == TypeCode.JSON:
return value_pb.string_value
return JsonObject.from_str(value_pb.string_value)
else:
raise ValueError("Unknown type: %s" % (field_type,))

Expand Down
33 changes: 32 additions & 1 deletion google/cloud/spanner_v1/data_types.py
Expand Up @@ -14,6 +14,8 @@

"""Custom data types for spanner."""

import json


class JsonObject(dict):
"""
Expand All @@ -22,4 +24,33 @@ class JsonObject(dict):
normal parameters and JSON parameters.
"""

pass
def __init__(self, *args, **kwargs):
self._is_null = (args, kwargs) == ((), {}) or args == (None,)
if not self._is_null:
super(JsonObject, self).__init__(*args, **kwargs)

@classmethod
def from_str(cls, str_repr):
"""Initiate an object from its `str` representation.
Args:
str_repr (str): JSON text representation.
Returns:
JsonObject: JSON object.
"""
if str_repr == "null":
return cls()

return cls(json.loads(str_repr))

def serialize(self):
"""Return the object text representation.
Returns:
str: JSON object text representation.
"""
if self._is_null:
return None

return json.dumps(self, sort_keys=True, separators=(",", ":"))
8 changes: 4 additions & 4 deletions samples/samples/snippets_test.py
Expand Up @@ -50,13 +50,13 @@ def sample_name():

@pytest.fixture(scope="module")
def create_instance_id():
""" Id for the low-cost instance. """
"""Id for the low-cost instance."""
return f"create-instance-{uuid.uuid4().hex[:10]}"


@pytest.fixture(scope="module")
def lci_instance_id():
""" Id for the low-cost instance. """
"""Id for the low-cost instance."""
return f"lci-instance-{uuid.uuid4().hex[:10]}"


Expand Down Expand Up @@ -91,7 +91,7 @@ def database_ddl():

@pytest.fixture(scope="module")
def default_leader():
""" Default leader for multi-region instances. """
"""Default leader for multi-region instances."""
return "us-east4"


Expand Down Expand Up @@ -582,7 +582,7 @@ def test_update_data_with_json(capsys, instance_id, sample_database):
def test_query_data_with_json_parameter(capsys, instance_id, sample_database):
snippets.query_data_with_json_parameter(instance_id, sample_database.database_id)
out, _ = capsys.readouterr()
assert "VenueId: 19, VenueDetails: {\"open\":true,\"rating\":9}" in out
assert "VenueId: 19, VenueDetails: {'open': True, 'rating': 9}" in out


@pytest.mark.dependency(depends=["insert_datatypes_data"])
Expand Down
2 changes: 1 addition & 1 deletion tests/system/test_dbapi.py
Expand Up @@ -364,7 +364,7 @@ def test_autocommit_with_json_data(shared_instance, dbapi_database):
# Assert the response
assert len(got_rows) == 1
assert got_rows[0][0] == 123
assert got_rows[0][1] == '{"age":"26","name":"Jakob"}'
assert got_rows[0][1] == {"age": "26", "name": "Jakob"}

# Drop the table
cur.execute("DROP TABLE JsonDetails")
Expand Down
14 changes: 4 additions & 10 deletions tests/system/test_session_api.py
Expand Up @@ -19,7 +19,6 @@
import struct
import threading
import time
import json
import pytest

import grpc
Expand All @@ -28,6 +27,7 @@
from google.api_core import exceptions
from google.cloud import spanner_v1
from google.cloud._helpers import UTC
from google.cloud.spanner_v1.data_types import JsonObject
from tests import _helpers as ot_helpers
from . import _helpers
from . import _sample_data
Expand All @@ -43,23 +43,17 @@
BYTES_2 = b"Ym9vdHM="
NUMERIC_1 = decimal.Decimal("0.123456789")
NUMERIC_2 = decimal.Decimal("1234567890")
JSON_1 = json.dumps(
JSON_1 = JsonObject(
{
"sample_boolean": True,
"sample_int": 872163,
"sample float": 7871.298,
"sample_null": None,
"sample_string": "abcdef",
"sample_array": [23, 76, 19],
},
sort_keys=True,
separators=(",", ":"),
)
JSON_2 = json.dumps(
{"sample_object": {"name": "Anamika", "id": 2635}},
sort_keys=True,
separators=(",", ":"),
}
)
JSON_2 = JsonObject({"sample_object": {"name": "Anamika", "id": 2635}},)

COUNTERS_TABLE = "counters"
COUNTERS_COLUMNS = ("name", "value")
Expand Down
16 changes: 12 additions & 4 deletions tests/unit/test__helpers.py
Expand Up @@ -567,14 +567,22 @@ def test_w_json(self):
from google.cloud.spanner_v1 import Type
from google.cloud.spanner_v1 import TypeCode

VALUE = json.dumps(
{"id": 27863, "Name": "Anamika"}, sort_keys=True, separators=(",", ":")
)
VALUE = {"id": 27863, "Name": "Anamika"}
str_repr = json.dumps(VALUE, sort_keys=True, separators=(",", ":"))

field_type = Type(code=TypeCode.JSON)
value_pb = Value(string_value=VALUE)
value_pb = Value(string_value=str_repr)

self.assertEqual(self._callFUT(value_pb, field_type), VALUE)

VALUE = None
str_repr = json.dumps(VALUE, sort_keys=True, separators=(",", ":"))

field_type = Type(code=TypeCode.JSON)
value_pb = Value(string_value=str_repr)

self.assertEqual(self._callFUT(value_pb, field_type), {})

def test_w_unknown_type(self):
from google.protobuf.struct_pb2 import Value
from google.cloud.spanner_v1 import Type
Expand Down

0 comments on commit d760c2c

Please sign in to comment.