diff --git a/django_spanner/__init__.py b/django_spanner/__init__.py index 0d88ffac91..cdbfa9a3ef 100644 --- a/django_spanner/__init__.py +++ b/django_spanner/__init__.py @@ -12,12 +12,14 @@ from uuid import uuid4 import pkg_resources +from google.cloud.spanner_v1 import JsonObject from django.db.models.fields import ( AutoField, SmallAutoField, BigAutoField, Field, ) +from django.db.models import JSONField # Monkey-patch google.DatetimeWithNanoseconds's __eq__ compare against # datetime.datetime. @@ -59,6 +61,17 @@ def autofield_init(self, *args, **kwargs): SmallAutoField.validators = [] BigAutoField.validators = [] + +def get_prep_value(self, value): + # Json encoding and decoding for spanner is done in python-spanner. + if not isinstance(value, JsonObject) and isinstance(value, dict): + return JsonObject(value) + + return value + + +JSONField.get_prep_value = get_prep_value + old_datetimewithnanoseconds_eq = getattr( DatetimeWithNanoseconds, "__eq__", None ) diff --git a/django_spanner/base.py b/django_spanner/base.py index 00033bf223..25c42416a5 100644 --- a/django_spanner/base.py +++ b/django_spanner/base.py @@ -34,6 +34,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): "DateField": "DATE", "DateTimeField": "TIMESTAMP", "DecimalField": "NUMERIC", + "JSONField": "JSON", "DurationField": "INT64", "EmailField": "STRING(%(max_length)s)", "FileField": "STRING(%(max_length)s)", diff --git a/django_spanner/features.py b/django_spanner/features.py index 2417b0b95f..86120329e8 100644 --- a/django_spanner/features.py +++ b/django_spanner/features.py @@ -8,6 +8,7 @@ from django.db.backends.base.features import BaseDatabaseFeatures from django.db.utils import InterfaceError +from django_spanner import USE_EMULATOR class DatabaseFeatures(BaseDatabaseFeatures): @@ -34,8 +35,11 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_column_check_constraints = True supports_table_check_constraints = True supports_order_by_nulls_modifier = False - # Spanner does not support json - supports_json_field = False + if USE_EMULATOR: + # Emulator does not support json. + supports_json_field = False + else: + supports_json_field = True supports_primitives_in_json_field = False # Spanner does not support SELECTing an arbitrary expression that also # appears in the GROUP BY clause. @@ -67,7 +71,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): "model_fields.test_autofield.SmallAutoFieldTests.test_redundant_backend_range_validators", # Spanner does not support deferred unique constraints "migrations.test_operations.OperationTests.test_create_model_with_deferred_unique_constraint", - # Spanner does not support JSON objects + # Spanner does not support JSON object query on fields. "db_functions.comparison.test_json_object.JSONObjectTests.test_empty", "db_functions.comparison.test_json_object.JSONObjectTests.test_basic", "db_functions.comparison.test_json_object.JSONObjectTests.test_expressions", @@ -268,17 +272,11 @@ class DatabaseFeatures(BaseDatabaseFeatures): "timezones.tests.NewDatabaseTests.test_query_datetimes", # using NULL with + crashes: https://github.com/googleapis/python-spanner-django/issues/201 "annotations.tests.NonAggregateAnnotationTestCase.test_combined_annotation_commutative", - # Spanner loses DecimalField precision due to conversion to float: - # https://github.com/googleapis/python-spanner-django/pull/133#pullrequestreview-328482925 - "aggregation.tests.AggregateTestCase.test_decimal_max_digits_has_no_effect", - "aggregation.tests.AggregateTestCase.test_related_aggregate", + # Spanner does not support custom precision on DecimalField "db_functions.comparison.test_cast.CastTests.test_cast_to_decimal_field", "model_fields.test_decimalfield.DecimalFieldTests.test_fetch_from_db_without_float_rounding", "model_fields.test_decimalfield.DecimalFieldTests.test_roundtrip_with_trailing_zeros", - # Spanner does not support unsigned integer field. - "model_fields.test_integerfield.PositiveIntegerFieldTests.test_negative_values", - # Spanner doesn't support the variance the standard deviation database - # functions: + # Spanner doesn't support the variance the standard deviation database functions on full population. "aggregation.test_filter_argument.FilteredAggregateTests.test_filtered_numerical_aggregates", "aggregation_regress.tests.AggregationTests.test_stddev", # SELECT list expression references which is neither grouped @@ -358,12 +356,8 @@ class DatabaseFeatures(BaseDatabaseFeatures): "transaction_hooks.tests.TestConnectionOnCommit.test_discards_hooks_from_rolled_back_savepoint", "transaction_hooks.tests.TestConnectionOnCommit.test_inner_savepoint_rolled_back_with_outer", "transaction_hooks.tests.TestConnectionOnCommit.test_inner_savepoint_does_not_affect_outer", - # Spanner doesn't support views. - "inspectdb.tests.InspectDBTransactionalTests.test_include_views", - "introspection.tests.IntrospectionTests.test_table_names_with_views", - # Fields: JSON, GenericIPAddressField are mapped to String in Spanner + # Field: GenericIPAddressField is mapped to String in Spanner "inspectdb.tests.InspectDBTestCase.test_field_types", - "inspectdb.tests.InspectDBTestCase.test_json_field", # BigIntegerField is mapped to IntegerField in Spanner "inspectdb.tests.InspectDBTestCase.test_number_field_types", # No sequence for AutoField in Spanner. @@ -479,6 +473,9 @@ class DatabaseFeatures(BaseDatabaseFeatures): if os.environ.get("SPANNER_EMULATOR_HOST", None): # Some code isn't yet supported by the Spanner emulator. skip_tests += ( + # Views are not supported by emulator + "inspectdb.tests.InspectDBTransactionalTests.test_include_views", # noqa + "introspection.tests.IntrospectionTests.test_table_names_with_views", # noqa # Untyped parameters are not supported: # https://github.com/GoogleCloudPlatform/cloud-spanner-emulator#features-and-limitations "auth_tests.test_views.PasswordResetTest.test_confirm_custom_reset_url_token_link_redirects_to_set_password_page", # noqa @@ -1588,7 +1585,6 @@ class DatabaseFeatures(BaseDatabaseFeatures): "queries.tests.Queries1Tests.test_ticket2306", # noqa "queries.tests.Queries1Tests.test_ticket2400", # noqa "queries.tests.Queries1Tests.test_ticket2496", # noqa - # "queries.tests.Queries1Tests.test_ticket2902", # noqa "queries.tests.Queries1Tests.test_ticket3037", # noqa "queries.tests.Queries1Tests.test_ticket3141", # noqa "queries.tests.Queries1Tests.test_ticket4358", # noqa @@ -1812,7 +1808,6 @@ class DatabaseFeatures(BaseDatabaseFeatures): "sitemaps_tests.test_http.HTTPSitemapTests.test_paged_sitemap", # noqa "sitemaps_tests.test_http.HTTPSitemapTests.test_requestsite_sitemap", # noqa "sitemaps_tests.test_http.HTTPSitemapTests.test_simple_custom_sitemap", # noqa - # "sitemaps_tests.test_http.HTTPSitemapTests.test_simple_i18nsitemap_index", # noqa "sitemaps_tests.test_http.HTTPSitemapTests.test_alternate_i18n_sitemap_index", # noqa "sitemaps_tests.test_http.HTTPSitemapTests.test_alternate_i18n_sitemap_limited", # noqa "sitemaps_tests.test_http.HTTPSitemapTests.test_alternate_i18n_sitemap_xdefault", # noqa diff --git a/django_spanner/introspection.py b/django_spanner/introspection.py index 95db6723d5..a9ff28b0d7 100644 --- a/django_spanner/introspection.py +++ b/django_spanner/introspection.py @@ -11,6 +11,7 @@ ) from django.db.models import Index from google.cloud.spanner_v1 import TypeCode +from django_spanner import USE_EMULATOR class DatabaseIntrospection(BaseDatabaseIntrospection): @@ -25,7 +26,28 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): TypeCode.STRING: "CharField", TypeCode.TIMESTAMP: "DateTimeField", TypeCode.NUMERIC: "DecimalField", + TypeCode.JSON: "JSONField", } + if USE_EMULATOR: + # Emulator does not support table_type yet. + # https://github.com/GoogleCloudPlatform/cloud-spanner-emulator/issues/43 + LIST_TABLE_SQL = """ + SELECT + t.table_name, t.table_name + FROM + information_schema.tables AS t + WHERE + t.table_catalog = '' and t.table_schema = '' + """ + else: + LIST_TABLE_SQL = """ + SELECT + t.table_name, t.table_type + FROM + information_schema.tables AS t + WHERE + t.table_catalog = '' and t.table_schema = '' + """ def get_field_type(self, data_type, description): """A hook for a Spanner database to use the cursor description to @@ -53,8 +75,15 @@ def get_table_list(self, cursor): :rtype: list :returns: A list of table and view names in the current database. """ + results = cursor.run_sql_in_snapshot(self.LIST_TABLE_SQL) + tables = [] # The second TableInfo field is 't' for table or 'v' for view. - return [TableInfo(row[0], "t") for row in cursor.list_tables()] + for row in results: + table_type = "t" + if row[1] == "VIEW": + table_type = "v" + tables.append(TableInfo(row[0], table_type)) + return tables def get_table_description(self, cursor, table_name): """Return a description of the table with the DB-API cursor.description diff --git a/setup.py b/setup.py index 25e19457b3..ed485ec086 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ # 'Development Status :: 4 - Beta' # 'Development Status :: 5 - Production/Stable' release_status = "Development Status :: 4 - Beta" -dependencies = ["sqlparse >= 0.3.0", "google-cloud-spanner >= 3.0.0"] +dependencies = ["sqlparse >= 0.3.0", "google-cloud-spanner >= 3.11.1"] extras = { "tracing": [ "opentelemetry-api >= 1.1.0", diff --git a/testing/constraints-3.6.txt b/testing/constraints-3.6.txt index 7573802344..8bb73e0b1b 100644 --- a/testing/constraints-3.6.txt +++ b/testing/constraints-3.6.txt @@ -6,7 +6,7 @@ # e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev", # Then this file should have foo==1.14.0 sqlparse==0.3.0 -google-cloud-spanner==3.0.0 +google-cloud-spanner==3.11.1 opentelemetry-api==1.1.0 opentelemetry-sdk==1.1.0 opentelemetry-instrumentation==0.20b0 diff --git a/tests/system/django_spanner/models.py b/tests/system/django_spanner/models.py index f7153ba994..edf0a807e6 100644 --- a/tests/system/django_spanner/models.py +++ b/tests/system/django_spanner/models.py @@ -34,3 +34,7 @@ class Meta: name="check_start_date", ), ] + + +class Detail(models.Model): + value = models.JSONField() diff --git a/tests/system/django_spanner/test_json_field.py b/tests/system/django_spanner/test_json_field.py new file mode 100644 index 0000000000..ab19ef6f63 --- /dev/null +++ b/tests/system/django_spanner/test_json_field.py @@ -0,0 +1,48 @@ +# Copyright 2021 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +from .models import Detail +import unittest +from django.test import TransactionTestCase +from django.db import connection +from django_spanner import USE_EMULATOR +from tests.system.django_spanner.utils import ( + setup_instance, + teardown_instance, + setup_database, + teardown_database, +) + + +@unittest.skipIf(USE_EMULATOR, "Jsonfield is not implemented in emulator.") +class TestJsonField(TransactionTestCase): + @classmethod + def setUpClass(cls): + setup_instance() + setup_database() + with connection.schema_editor() as editor: + # Create the tables + editor.create_model(Detail) + + @classmethod + def tearDownClass(cls): + with connection.schema_editor() as editor: + # delete the table + editor.delete_model(Detail) + teardown_database() + teardown_instance() + + def test_insert_and_fetch_value(self): + """ + Tests model object creation with Detail model. + Inserting json data into the model and retrieving it. + """ + json_data = Detail(value={"name": "Jakob", "age": "26"}) + json_data.save() + qs1 = Detail.objects.all() + self.assertEqual(qs1[0].value, {"name": "Jakob", "age": "26"}) + # Delete data from Detail table. + Detail.objects.all().delete() diff --git a/tests/unit/django_spanner/test_introspection.py b/tests/unit/django_spanner/test_introspection.py index 03b5b67ca9..a86c65cdbb 100644 --- a/tests/unit/django_spanner/test_introspection.py +++ b/tests/unit/django_spanner/test_introspection.py @@ -49,9 +49,9 @@ def test_get_table_list(self): cursor = mock.MagicMock() def list_tables(*args, **kwargs): - return [["Table_1"], ["Table_2"]] + return [["Table_1", "t"], ["Table_2", "t"]] - cursor.list_tables = list_tables + cursor.run_sql_in_snapshot = list_tables table_list = db_introspection.get_table_list(cursor=cursor) self.assertEqual( table_list,