diff --git a/google/cloud/bigquery/_pandas_helpers.py b/google/cloud/bigquery/_pandas_helpers.py index 0a22043a3..869c0215d 100644 --- a/google/cloud/bigquery/_pandas_helpers.py +++ b/google/cloud/bigquery/_pandas_helpers.py @@ -173,6 +173,13 @@ def pyarrow_timestamp(): pyarrow.decimal128(38, scale=9).id: "NUMERIC", pyarrow.decimal256(76, scale=38).id: "BIGNUMERIC", } + BQ_FIELD_TYPE_TO_ARROW_FIELD_METADATA = { + "GEOGRAPHY": { + b"ARROW:extension:name": b"google:sqlType:geography", + b"ARROW:extension:metadata": b'{"encoding": "WKT"}', + }, + "DATETIME": {b"ARROW:extension:name": b"google:sqlType:datetime"}, + } else: # pragma: NO COVER BQ_TO_ARROW_SCALARS = {} # pragma: NO COVER @@ -227,7 +234,12 @@ def bq_to_arrow_field(bq_field, array_type=None): if array_type is not None: arrow_type = array_type # For GEOGRAPHY, at least initially is_nullable = bq_field.mode.upper() == "NULLABLE" - return pyarrow.field(bq_field.name, arrow_type, nullable=is_nullable) + metadata = BQ_FIELD_TYPE_TO_ARROW_FIELD_METADATA.get( + bq_field.field_type.upper() if bq_field.field_type else "" + ) + return pyarrow.field( + bq_field.name, arrow_type, nullable=is_nullable, metadata=metadata + ) warnings.warn("Unable to determine type for field '{}'.".format(bq_field.name)) return None diff --git a/google/cloud/bigquery/table.py b/google/cloud/bigquery/table.py index 609c0b57e..c4a45dc83 100644 --- a/google/cloud/bigquery/table.py +++ b/google/cloud/bigquery/table.py @@ -1810,10 +1810,14 @@ def to_arrow( if owns_bqstorage_client: bqstorage_client._transport.grpc_channel.close() - if record_batches: + if record_batches and bqstorage_client is not None: return pyarrow.Table.from_batches(record_batches) else: - # No records, use schema based on BigQuery schema. + # No records (not record_batches), use schema based on BigQuery schema + # **or** + # we used the REST API (bqstorage_client is None), + # which doesn't add arrow extension metadata, so we let + # `bq_to_arrow_schema` do it. arrow_schema = _pandas_helpers.bq_to_arrow_schema(self._schema) return pyarrow.Table.from_batches(record_batches, schema=arrow_schema) diff --git a/tests/system/conftest.py b/tests/system/conftest.py index cc2c2a4dc..7eec76a32 100644 --- a/tests/system/conftest.py +++ b/tests/system/conftest.py @@ -13,6 +13,7 @@ # limitations under the License. import pathlib +import re import pytest import test_utils.prefixer @@ -61,6 +62,17 @@ def dataset_id(bigquery_client): bigquery_client.delete_dataset(dataset_id, delete_contents=True, not_found_ok=True) +@pytest.fixture() +def dataset_client(bigquery_client, dataset_id): + import google.cloud.bigquery.job + + return bigquery.Client( + default_query_job_config=google.cloud.bigquery.job.QueryJobConfig( + default_dataset=f"{bigquery_client.project}.{dataset_id}", + ) + ) + + @pytest.fixture def table_id(dataset_id): return f"{dataset_id}.table_{helpers.temp_suffix()}" @@ -98,3 +110,8 @@ def scalars_extreme_table( job.result() yield full_table_id bigquery_client.delete_table(full_table_id) + + +@pytest.fixture +def test_table_name(request, replace_non_anum=re.compile(r"[^a-zA-Z0-9_]").sub): + return replace_non_anum("_", request.node.name) diff --git a/tests/system/test_arrow.py b/tests/system/test_arrow.py index 12f7af9cb..96f9dea25 100644 --- a/tests/system/test_arrow.py +++ b/tests/system/test_arrow.py @@ -110,3 +110,62 @@ def test_list_rows_nullable_scalars_dtypes( timestamp_type = schema.field("timestamp_col").type assert timestamp_type.unit == "us" assert timestamp_type.tz is not None + + +@pytest.mark.parametrize("do_insert", [True, False]) +def test_arrow_extension_types_same_for_storage_and_REST_APIs_894( + dataset_client, test_table_name, do_insert +): + types = dict( + astring=("STRING", "'x'"), + astring9=("STRING(9)", "'x'"), + abytes=("BYTES", "b'x'"), + abytes9=("BYTES(9)", "b'x'"), + anumeric=("NUMERIC", "42"), + anumeric9=("NUMERIC(9)", "42"), + anumeric92=("NUMERIC(9,2)", "42"), + abignumeric=("BIGNUMERIC", "42e30"), + abignumeric49=("BIGNUMERIC(37)", "42e30"), + abignumeric492=("BIGNUMERIC(37,2)", "42e30"), + abool=("BOOL", "true"), + adate=("DATE", "'2021-09-06'"), + adatetime=("DATETIME", "'2021-09-06T09:57:26'"), + ageography=("GEOGRAPHY", "ST_GEOGFROMTEXT('point(0 0)')"), + # Can't get arrow data for interval :( + # ainterval=('INTERVAL', "make_interval(1, 2, 3, 4, 5, 6)"), + aint64=("INT64", "42"), + afloat64=("FLOAT64", "42.0"), + astruct=("STRUCT", "struct(42)"), + atime=("TIME", "'1:2:3'"), + atimestamp=("TIMESTAMP", "'2021-09-06T09:57:26'"), + ) + columns = ", ".join(f"{k} {t[0]}" for k, t in types.items()) + dataset_client.query(f"create table {test_table_name} ({columns})").result() + if do_insert: + names = list(types) + values = ", ".join(types[name][1] for name in names) + names = ", ".join(names) + dataset_client.query( + f"insert into {test_table_name} ({names}) values ({values})" + ).result() + at = dataset_client.query(f"select * from {test_table_name}").result().to_arrow() + storage_api_metadata = { + at.field(i).name: at.field(i).metadata for i in range(at.num_columns) + } + at = ( + dataset_client.query(f"select * from {test_table_name}") + .result() + .to_arrow(create_bqstorage_client=False) + ) + rest_api_metadata = { + at.field(i).name: at.field(i).metadata for i in range(at.num_columns) + } + + assert rest_api_metadata == storage_api_metadata + assert rest_api_metadata["adatetime"] == { + b"ARROW:extension:name": b"google:sqlType:datetime" + } + assert rest_api_metadata["ageography"] == { + b"ARROW:extension:name": b"google:sqlType:geography", + b"ARROW:extension:metadata": b'{"encoding": "WKT"}', + } diff --git a/tests/unit/test__pandas_helpers.py b/tests/unit/test__pandas_helpers.py index 80b226a3a..ef8c80c81 100644 --- a/tests/unit/test__pandas_helpers.py +++ b/tests/unit/test__pandas_helpers.py @@ -1696,3 +1696,26 @@ def test_bq_to_arrow_field_type_override(module_under_test): ).type == pyarrow.binary() ) + + +@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") +@pytest.mark.parametrize( + "field_type, metadata", + [ + ("datetime", {b"ARROW:extension:name": b"google:sqlType:datetime"}), + ( + "geography", + { + b"ARROW:extension:name": b"google:sqlType:geography", + b"ARROW:extension:metadata": b'{"encoding": "WKT"}', + }, + ), + ], +) +def test_bq_to_arrow_field_metadata(module_under_test, field_type, metadata): + assert ( + module_under_test.bq_to_arrow_field( + schema.SchemaField("g", field_type) + ).metadata + == metadata + )