From 410f64ba2880653480df2f4790a4403d63e9bad6 Mon Sep 17 00:00:00 2001
From: Jim Fulton
Date: Mon, 6 Sep 2021 12:34:53 -0600
Subject: [PATCH 1/3] fix: Arrow extension-type metadata not set when calling the REST API or when there are no rows

---
 google/cloud/bigquery/_pandas_helpers.py | 14 ++++++-
 google/cloud/bigquery/table.py           |  2 +-
 tests/system/conftest.py                 | 17 ++++++++
 tests/system/test_arrow.py               | 53 ++++++++++++++++++++++++
 tests/unit/test__pandas_helpers.py       | 23 ++++++++++
 5 files changed, 107 insertions(+), 2 deletions(-)

diff --git a/google/cloud/bigquery/_pandas_helpers.py b/google/cloud/bigquery/_pandas_helpers.py
index 0a22043a3..869c0215d 100644
--- a/google/cloud/bigquery/_pandas_helpers.py
+++ b/google/cloud/bigquery/_pandas_helpers.py
@@ -173,6 +173,13 @@ def pyarrow_timestamp():
         pyarrow.decimal128(38, scale=9).id: "NUMERIC",
         pyarrow.decimal256(76, scale=38).id: "BIGNUMERIC",
     }
+    BQ_FIELD_TYPE_TO_ARROW_FIELD_METADATA = {
+        "GEOGRAPHY": {
+            b"ARROW:extension:name": b"google:sqlType:geography",
+            b"ARROW:extension:metadata": b'{"encoding": "WKT"}',
+        },
+        "DATETIME": {b"ARROW:extension:name": b"google:sqlType:datetime"},
+    }
 
 else:  # pragma: NO COVER
     BQ_TO_ARROW_SCALARS = {}  # pragma: NO COVER
@@ -227,7 +234,12 @@ def bq_to_arrow_field(bq_field, array_type=None):
         if array_type is not None:
             arrow_type = array_type  # For GEOGRAPHY, at least initially
         is_nullable = bq_field.mode.upper() == "NULLABLE"
-        return pyarrow.field(bq_field.name, arrow_type, nullable=is_nullable)
+        metadata = BQ_FIELD_TYPE_TO_ARROW_FIELD_METADATA.get(
+            bq_field.field_type.upper() if bq_field.field_type else ""
+        )
+        return pyarrow.field(
+            bq_field.name, arrow_type, nullable=is_nullable, metadata=metadata
+        )
 
     warnings.warn("Unable to determine type for field '{}'.".format(bq_field.name))
     return None
diff --git a/google/cloud/bigquery/table.py b/google/cloud/bigquery/table.py
index 609c0b57e..dd78b211a 100644
--- a/google/cloud/bigquery/table.py
+++ b/google/cloud/bigquery/table.py
@@ -1810,7 +1810,7 @@ def to_arrow(
             if owns_bqstorage_client:
                 bqstorage_client._transport.grpc_channel.close()
 
-        if record_batches:
+        if record_batches and bqstorage_client is not None:
             return pyarrow.Table.from_batches(record_batches)
         else:
             # No records, use schema based on BigQuery schema.
diff --git a/tests/system/conftest.py b/tests/system/conftest.py
index cc2c2a4dc..7eec76a32 100644
--- a/tests/system/conftest.py
+++ b/tests/system/conftest.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import pathlib
+import re
 
 import pytest
 import test_utils.prefixer
@@ -61,6 +62,17 @@ def dataset_id(bigquery_client):
     bigquery_client.delete_dataset(dataset_id, delete_contents=True, not_found_ok=True)
 
 
+@pytest.fixture()
+def dataset_client(bigquery_client, dataset_id):
+    import google.cloud.bigquery.job
+
+    return bigquery.Client(
+        default_query_job_config=google.cloud.bigquery.job.QueryJobConfig(
+            default_dataset=f"{bigquery_client.project}.{dataset_id}",
+        )
+    )
+
+
 @pytest.fixture
 def table_id(dataset_id):
     return f"{dataset_id}.table_{helpers.temp_suffix()}"
@@ -98,3 +110,8 @@ def scalars_extreme_table(
     job.result()
     yield full_table_id
     bigquery_client.delete_table(full_table_id)
+
+
+@pytest.fixture
+def test_table_name(request, replace_non_anum=re.compile(r"[^a-zA-Z0-9_]").sub):
+    return replace_non_anum("_", request.node.name)
diff --git a/tests/system/test_arrow.py b/tests/system/test_arrow.py
index 12f7af9cb..f033d0b78 100644
--- a/tests/system/test_arrow.py
+++ b/tests/system/test_arrow.py
@@ -110,3 +110,56 @@ def test_list_rows_nullable_scalars_dtypes(
     timestamp_type = schema.field("timestamp_col").type
     assert timestamp_type.unit == "us"
     assert timestamp_type.tz is not None
+
+
+@pytest.mark.parametrize("do_insert", [True, False])
+def test_arrow_extension_types_same_for_storage_and_REST_APIs_894(
+    dataset_client, test_table_name, do_insert
+):
+    types = dict(
+        astring=("STRING", "'x'"),
+        astring9=("STRING(9)", "'x'"),
+        abytes=("BYTES", "b'x'"),
+        abytes9=("BYTES(9)", "b'x'"),
+        anumeric=("NUMERIC", "42"),
+        anumeric9=("NUMERIC(9)", "42"),
+        anumeric92=("NUMERIC(9,2)", "42"),
+        abignumeric=("BIGNUMERIC", "42e30"),
+        abignumeric49=("BIGNUMERIC(37)", "42e30"),
+        abignumeric492=("BIGNUMERIC(37,2)", "42e30"),
+        abool=("BOOL", "true"),
+        adate=("DATE", "'2021-09-06'"),
+        adatetime=("DATETIME", "'2021-09-06T09:57:26'"),
+        ageography=("GEOGRAPHY", "ST_GEOGFROMTEXT('point(0 0)')"),
+        # Can't get arrow data for interval :(
+        # ainterval=('INTERVAL', "make_interval(1, 2, 3, 4, 5, 6)"),
+        aint64=("INT64", "42"),
+        afloat64=("FLOAT64", "42.0"),
+        astruct=("STRUCT", "struct(42)"),
+        atime=("TIME", "'1:2:3'"),
+        atimestamp=("TIMESTAMP", "'2021-09-06T09:57:26'"),
+    )
+    columns = ", ".join(f"{k} {t[0]}" for k, t in types.items())
+    dataset_client.query(f"create table {test_table_name} ({columns})").result()
+    if do_insert:
+        names = list(types)
+        values = ", ".join(types[name][1] for name in names)
+        names = ", ".join(names)
+        dataset_client.query(
+            f"insert into {test_table_name} ({names}) values ({values})"
+        ).result()
+    at = dataset_client.query(f"select * from {test_table_name}").result().to_arrow()
+    smd = {at.field(i).name: at.field(i).metadata for i in range(at.num_columns)}
+    at = (
+        dataset_client.query(f"select * from {test_table_name}")
+        .result()
+        .to_arrow(create_bqstorage_client=False)
+    )
+    rmd = {at.field(i).name: at.field(i).metadata for i in range(at.num_columns)}
+
+    assert rmd == smd
+    assert rmd["adatetime"] == {b"ARROW:extension:name": b"google:sqlType:datetime"}
+    assert rmd["ageography"] == {
+        b"ARROW:extension:name": b"google:sqlType:geography",
+        b"ARROW:extension:metadata": b'{"encoding": "WKT"}',
+    }
diff --git a/tests/unit/test__pandas_helpers.py b/tests/unit/test__pandas_helpers.py
index 80b226a3a..ef8c80c81 100644
--- a/tests/unit/test__pandas_helpers.py
+++ b/tests/unit/test__pandas_helpers.py
@@ -1696,3 +1696,26 @@ def test_bq_to_arrow_field_type_override(module_under_test):
         ).type
         == pyarrow.binary()
     )
+
+
+@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`")
+@pytest.mark.parametrize(
+    "field_type, metadata",
+    [
+        ("datetime", {b"ARROW:extension:name": b"google:sqlType:datetime"}),
+        (
+            "geography",
+            {
+                b"ARROW:extension:name": b"google:sqlType:geography",
+                b"ARROW:extension:metadata": b'{"encoding": "WKT"}',
+            },
+        ),
+    ],
+)
+def test_bq_to_arrow_field_metadata(module_under_test, field_type, metadata):
+    assert (
+        module_under_test.bq_to_arrow_field(
+            schema.SchemaField("g", field_type)
+        ).metadata
+        == metadata
+    )

From 53ed313794bb4a1eba9df20ee3208f470ea12cfb Mon Sep 17 00:00:00 2001
From: Jim Fulton
Date: Tue, 7 Sep 2021 08:46:43 -0600
Subject: [PATCH 2/3] Add comment explaining the two checks

---
 google/cloud/bigquery/table.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/google/cloud/bigquery/table.py b/google/cloud/bigquery/table.py
index dd78b211a..c4a45dc83 100644
--- a/google/cloud/bigquery/table.py
+++ b/google/cloud/bigquery/table.py
@@ -1813,7 +1813,11 @@ def to_arrow(
         if record_batches and bqstorage_client is not None:
             return pyarrow.Table.from_batches(record_batches)
         else:
-            # No records, use schema based on BigQuery schema.
+            # No records (not record_batches), use schema based on BigQuery schema
+            # **or**
+            # we used the REST API (bqstorage_client is None),
+            # which doesn't add arrow extension metadata, so we let
+            # `bq_to_arrow_schema` do it.
             arrow_schema = _pandas_helpers.bq_to_arrow_schema(self._schema)
             return pyarrow.Table.from_batches(record_batches, schema=arrow_schema)
 
From 683d6dfe5b01cc141b25373d3d844fd97ba04037 Mon Sep 17 00:00:00 2001
From: Jim Fulton
Date: Tue, 7 Sep 2021 08:53:36 -0600
Subject: [PATCH 3/3] better variable names

---
 tests/system/test_arrow.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/tests/system/test_arrow.py b/tests/system/test_arrow.py
index f033d0b78..96f9dea25 100644
--- a/tests/system/test_arrow.py
+++ b/tests/system/test_arrow.py
@@ -149,17 +149,23 @@ def test_arrow_extension_types_same_for_storage_and_REST_APIs_894(
             f"insert into {test_table_name} ({names}) values ({values})"
         ).result()
     at = dataset_client.query(f"select * from {test_table_name}").result().to_arrow()
-    smd = {at.field(i).name: at.field(i).metadata for i in range(at.num_columns)}
+    storage_api_metadata = {
+        at.field(i).name: at.field(i).metadata for i in range(at.num_columns)
+    }
     at = (
         dataset_client.query(f"select * from {test_table_name}")
         .result()
         .to_arrow(create_bqstorage_client=False)
     )
-    rmd = {at.field(i).name: at.field(i).metadata for i in range(at.num_columns)}
+    rest_api_metadata = {
+        at.field(i).name: at.field(i).metadata for i in range(at.num_columns)
+    }
 
-    assert rmd == smd
-    assert rmd["adatetime"] == {b"ARROW:extension:name": b"google:sqlType:datetime"}
-    assert rmd["ageography"] == {
+    assert rest_api_metadata == storage_api_metadata
+    assert rest_api_metadata["adatetime"] == {
+        b"ARROW:extension:name": b"google:sqlType:datetime"
+    }
+    assert rest_api_metadata["ageography"] == {
         b"ARROW:extension:name": b"google:sqlType:geography",
         b"ARROW:extension:metadata": b'{"encoding": "WKT"}',
     }
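
For context, a minimal usage sketch of the behavior these patches address: with the fix, Arrow extension metadata for DATETIME and GEOGRAPHY columns is attached even when results are fetched over the REST API (create_bqstorage_client=False) or when a query returns no rows. The table and column names below (my_dataset.my_table, dt, geo) are hypothetical and not taken from the patches.

    # Sketch only: assumes a table with DATETIME column `dt` and GEOGRAPHY column `geo`.
    from google.cloud import bigquery

    client = bigquery.Client()
    arrow_table = (
        client.query("SELECT dt, geo FROM my_dataset.my_table")
        .result()
        .to_arrow(create_bqstorage_client=False)  # force the REST API code path
    )

    # With the fix, the REST path carries the same extension metadata as the Storage API path.
    assert arrow_table.field("dt").metadata == {
        b"ARROW:extension:name": b"google:sqlType:datetime"
    }
    assert arrow_table.field("geo").metadata == {
        b"ARROW:extension:name": b"google:sqlType:geography",
        b"ARROW:extension:metadata": b'{"encoding": "WKT"}',
    }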