fix: Arrow extension-type metadata was not set when calling the REST API or when there are no rows #946

Merged
merged 6 commits on Sep 7, 2021
Changes from all commits
14 changes: 13 additions & 1 deletion google/cloud/bigquery/_pandas_helpers.py
@@ -173,6 +173,13 @@ def pyarrow_timestamp():
         pyarrow.decimal128(38, scale=9).id: "NUMERIC",
         pyarrow.decimal256(76, scale=38).id: "BIGNUMERIC",
     }
+    BQ_FIELD_TYPE_TO_ARROW_FIELD_METADATA = {
+        "GEOGRAPHY": {
+            b"ARROW:extension:name": b"google:sqlType:geography",
+            b"ARROW:extension:metadata": b'{"encoding": "WKT"}',
+        },
+        "DATETIME": {b"ARROW:extension:name": b"google:sqlType:datetime"},
+    }
 
 else:  # pragma: NO COVER
     BQ_TO_ARROW_SCALARS = {}  # pragma: NO COVER
@@ -227,7 +234,12 @@ def bq_to_arrow_field(bq_field, array_type=None):
         if array_type is not None:
             arrow_type = array_type  # For GEOGRAPHY, at least initially
         is_nullable = bq_field.mode.upper() == "NULLABLE"
-        return pyarrow.field(bq_field.name, arrow_type, nullable=is_nullable)
+        metadata = BQ_FIELD_TYPE_TO_ARROW_FIELD_METADATA.get(
+            bq_field.field_type.upper() if bq_field.field_type else ""
+        )
+        return pyarrow.field(
+            bq_field.name, arrow_type, nullable=is_nullable, metadata=metadata
+        )
 
     warnings.warn("Unable to determine type for field '{}'.".format(bq_field.name))
     return None
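Note for readers: `pyarrow.field` accepts a `metadata` mapping of bytes keys to bytes values, and that mapping is preserved on the field and in any schema built from it; for types outside the new mapping, `.get(...)` returns None, which `pyarrow.field` treats as "no metadata". A minimal standalone sketch (the field name "geo" and the standalone construction are illustrative, not part of this change):

import pyarrow

# Attach the same extension metadata that bq_to_arrow_field now sets for GEOGRAPHY.
geo = pyarrow.field(
    "geo",
    pyarrow.string(),
    nullable=True,
    metadata={
        b"ARROW:extension:name": b"google:sqlType:geography",
        b"ARROW:extension:metadata": b'{"encoding": "WKT"}',
    },
)
print(geo.metadata)
# {b'ARROW:extension:name': b'google:sqlType:geography',
#  b'ARROW:extension:metadata': b'{"encoding": "WKT"}'}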
8 changes: 6 additions & 2 deletions google/cloud/bigquery/table.py
@@ -1810,10 +1810,14 @@ def to_arrow(
             if owns_bqstorage_client:
                 bqstorage_client._transport.grpc_channel.close()
 
-        if record_batches:
+        if record_batches and bqstorage_client is not None:
             return pyarrow.Table.from_batches(record_batches)
         else:
-            # No records, use schema based on BigQuery schema.
+            # Either there are no records (`record_batches` is empty), so we
+            # build the schema from the BigQuery schema, **or** we used the
+            # REST API (`bqstorage_client` is None), which does not add Arrow
+            # extension metadata, so we let `bq_to_arrow_schema` add it
+            # instead.
             arrow_schema = _pandas_helpers.bq_to_arrow_schema(self._schema)
             return pyarrow.Table.from_batches(record_batches, schema=arrow_schema)
 
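Why the explicit schema also fixes the zero-row case: `pyarrow.Table.from_batches` cannot infer a schema from an empty batch list, but it accepts one explicitly, and the field metadata from the schema survives on the resulting table. A rough sketch of that behavior, with a hand-built schema standing in for the real one:

import pyarrow

# Illustrative schema only; the real one comes from
# _pandas_helpers.bq_to_arrow_schema(self._schema).
arrow_schema = pyarrow.schema(
    [
        pyarrow.field(
            "dt",
            pyarrow.timestamp("us"),
            metadata={b"ARROW:extension:name": b"google:sqlType:datetime"},
        )
    ]
)
empty = pyarrow.Table.from_batches([], schema=arrow_schema)
print(empty.num_rows)  # 0
print(empty.schema.field("dt").metadata)  # extension metadata is preserved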
17 changes: 17 additions & 0 deletions tests/system/conftest.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import pathlib
+import re
 
 import pytest
 import test_utils.prefixer
@@ -61,6 +62,17 @@ def dataset_id(bigquery_client):
     bigquery_client.delete_dataset(dataset_id, delete_contents=True, not_found_ok=True)
 
 
+@pytest.fixture()
+def dataset_client(bigquery_client, dataset_id):
+    import google.cloud.bigquery.job
+
+    return bigquery.Client(
+        default_query_job_config=google.cloud.bigquery.job.QueryJobConfig(
+            default_dataset=f"{bigquery_client.project}.{dataset_id}",
+        )
+    )
+
+
 @pytest.fixture
 def table_id(dataset_id):
     return f"{dataset_id}.table_{helpers.temp_suffix()}"
@@ -98,3 +110,8 @@ def scalars_extreme_table(
     job.result()
     yield full_table_id
     bigquery_client.delete_table(full_table_id)
+
+
+@pytest.fixture
+def test_table_name(request, replace_non_anum=re.compile(r"[^a-zA-Z0-9_]").sub):
+    return replace_non_anum("_", request.node.name)
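The `replace_non_anum` default argument binds the compiled pattern's `sub` method once, at fixture definition time; any character in the pytest node name that is not a letter, digit, or underscore becomes an underscore, so even parametrized test names are valid table IDs. A quick sketch with a hypothetical node name:

import re

replace_non_anum = re.compile(r"[^a-zA-Z0-9_]").sub
# Brackets from parametrization are not legal in a table ID, so they are
# rewritten; the node name below is hypothetical.
print(replace_non_anum("_", "test_arrow_extension_types[True]"))
# test_arrow_extension_types_True_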
59 changes: 59 additions & 0 deletions tests/system/test_arrow.py
@@ -110,3 +110,62 @@ def test_list_rows_nullable_scalars_dtypes(
     timestamp_type = schema.field("timestamp_col").type
     assert timestamp_type.unit == "us"
     assert timestamp_type.tz is not None
+
+
+@pytest.mark.parametrize("do_insert", [True, False])
+def test_arrow_extension_types_same_for_storage_and_REST_APIs_894(
+    dataset_client, test_table_name, do_insert
+):
+    types = dict(
+        astring=("STRING", "'x'"),
+        astring9=("STRING(9)", "'x'"),
+        abytes=("BYTES", "b'x'"),
+        abytes9=("BYTES(9)", "b'x'"),
+        anumeric=("NUMERIC", "42"),
+        anumeric9=("NUMERIC(9)", "42"),
+        anumeric92=("NUMERIC(9,2)", "42"),
+        abignumeric=("BIGNUMERIC", "42e30"),
+        abignumeric49=("BIGNUMERIC(37)", "42e30"),
+        abignumeric492=("BIGNUMERIC(37,2)", "42e30"),
+        abool=("BOOL", "true"),
+        adate=("DATE", "'2021-09-06'"),
+        adatetime=("DATETIME", "'2021-09-06T09:57:26'"),
+        ageography=("GEOGRAPHY", "ST_GEOGFROMTEXT('point(0 0)')"),
+        # Can't get arrow data for interval :(
+        # ainterval=('INTERVAL', "make_interval(1, 2, 3, 4, 5, 6)"),
+        aint64=("INT64", "42"),
+        afloat64=("FLOAT64", "42.0"),
+        astruct=("STRUCT<v int64>", "struct(42)"),
+        atime=("TIME", "'1:2:3'"),
+        atimestamp=("TIMESTAMP", "'2021-09-06T09:57:26'"),
+    )
+    columns = ", ".join(f"{k} {t[0]}" for k, t in types.items())
+    dataset_client.query(f"create table {test_table_name} ({columns})").result()
+    if do_insert:
+        names = list(types)
+        values = ", ".join(types[name][1] for name in names)
+        names = ", ".join(names)
+        dataset_client.query(
+            f"insert into {test_table_name} ({names}) values ({values})"
+        ).result()
+    at = dataset_client.query(f"select * from {test_table_name}").result().to_arrow()
+    storage_api_metadata = {
+        at.field(i).name: at.field(i).metadata for i in range(at.num_columns)
+    }
+    at = (
+        dataset_client.query(f"select * from {test_table_name}")
+        .result()
+        .to_arrow(create_bqstorage_client=False)
+    )
+    rest_api_metadata = {
+        at.field(i).name: at.field(i).metadata for i in range(at.num_columns)
+    }
+
+    assert rest_api_metadata == storage_api_metadata
+    assert rest_api_metadata["adatetime"] == {
+        b"ARROW:extension:name": b"google:sqlType:datetime"
+    }
+    assert rest_api_metadata["ageography"] == {
+        b"ARROW:extension:name": b"google:sqlType:geography",
+        b"ARROW:extension:metadata": b'{"encoding": "WKT"}',
+    }
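Passing `create_bqstorage_client=False` is what forces `to_arrow` down the REST path this PR fixes, and the `do_insert=False` case exercises the zero-row path. A hedged sketch of how a caller would observe the metadata the test asserts on (the client construction and table name are placeholders):

from google.cloud import bigquery

# Placeholder table name; after this change both API paths should report the
# same ARROW:extension:* metadata.
client = bigquery.Client()
rows = client.query("select * from my_dataset.my_table").result()
arrow_table = rows.to_arrow(create_bqstorage_client=False)  # REST API path
for field in arrow_table.schema:
    print(field.name, field.metadata)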
23 changes: 23 additions & 0 deletions tests/unit/test__pandas_helpers.py
@@ -1696,3 +1696,26 @@ def test_bq_to_arrow_field_type_override(module_under_test):
         ).type
         == pyarrow.binary()
     )
+
+
+@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`")
+@pytest.mark.parametrize(
+    "field_type, metadata",
+    [
+        ("datetime", {b"ARROW:extension:name": b"google:sqlType:datetime"}),
+        (
+            "geography",
+            {
+                b"ARROW:extension:name": b"google:sqlType:geography",
+                b"ARROW:extension:metadata": b'{"encoding": "WKT"}',
+            },
+        ),
+    ],
+)
+def test_bq_to_arrow_field_metadata(module_under_test, field_type, metadata):
+    assert (
+        module_under_test.bq_to_arrow_field(
+            schema.SchemaField("g", field_type)
+        ).metadata
+        == metadata
+    )
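The lower-case `field_type` values in the parametrization double as a check that `bq_to_arrow_field` upper-cases the type before the metadata lookup. A quick interactive sketch of the same behavior (assumes pyarrow and google-cloud-bigquery are installed; `_pandas_helpers` is a private module, so this is for exploration only):

from google.cloud.bigquery import _pandas_helpers, schema

# Lower-case type on purpose: the lookup normalizes with .upper().
field = _pandas_helpers.bq_to_arrow_field(schema.SchemaField("g", "geography"))
print(field.metadata[b"ARROW:extension:name"])  # b'google:sqlType:geography'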