fix: Arrow extension-type metadata was not set when calling the REST API or when there are no rows #946

Merged
6 commits merged on Sep 7, 2021
Changes from 2 commits
14 changes: 13 additions & 1 deletion google/cloud/bigquery/_pandas_helpers.py
@@ -173,6 +173,13 @@ def pyarrow_timestamp():
        pyarrow.decimal128(38, scale=9).id: "NUMERIC",
        pyarrow.decimal256(76, scale=38).id: "BIGNUMERIC",
    }
    BQ_FIELD_TYPE_TO_ARROW_FIELD_METADATA = {
        "GEOGRAPHY": {
            b"ARROW:extension:name": b"google:sqlType:geography",
            b"ARROW:extension:metadata": b'{"encoding": "WKT"}',
        },
        "DATETIME": {b"ARROW:extension:name": b"google:sqlType:datetime"},
    }

else: # pragma: NO COVER
    BQ_TO_ARROW_SCALARS = {}  # pragma: NO COVER
@@ -227,7 +234,12 @@ def bq_to_arrow_field(bq_field, array_type=None):
        if array_type is not None:
            arrow_type = array_type  # For GEOGRAPHY, at least initially
        is_nullable = bq_field.mode.upper() == "NULLABLE"
-       return pyarrow.field(bq_field.name, arrow_type, nullable=is_nullable)
+       metadata = BQ_FIELD_TYPE_TO_ARROW_FIELD_METADATA.get(
+           bq_field.field_type.upper() if bq_field.field_type else ""
+       )
+       return pyarrow.field(
+           bq_field.name, arrow_type, nullable=is_nullable, metadata=metadata
+       )

    warnings.warn("Unable to determine type for field '{}'.".format(bq_field.name))
    return None
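For readers skimming the change: a minimal sketch (not part of the diff) of how the new field-level metadata can be inspected after conversion; the column name geo is made up for the example.

# Minimal sketch (not part of this PR): inspecting the extension metadata
# that bq_to_arrow_field now attaches. The column name "geo" is illustrative.
from google.cloud.bigquery import _pandas_helpers, schema

geo_field = _pandas_helpers.bq_to_arrow_field(schema.SchemaField("geo", "GEOGRAPHY"))
print(geo_field.metadata)
# {b'ARROW:extension:name': b'google:sqlType:geography',
#  b'ARROW:extension:metadata': b'{"encoding": "WKT"}'}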
2 changes: 1 addition & 1 deletion google/cloud/bigquery/table.py
@@ -1810,7 +1810,7 @@ def to_arrow(
            if owns_bqstorage_client:
                bqstorage_client._transport.grpc_channel.close()

-       if record_batches:
+       if record_batches and bqstorage_client is not None:
            return pyarrow.Table.from_batches(record_batches)
        else:
            # No records, use schema based on BigQuery schema.
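The idea behind the new condition: when no record batches came back (for example, an empty result set, or rows fetched over the REST API), the Arrow table is built from the BigQuery schema instead, so the extension metadata still ends up on the fields. A rough sketch of that fallback, assuming bq_to_arrow_field and hypothetical column names; illustrative only, not the library's exact code path.

# Rough sketch (illustrative, not the library's exact code path): build an
# empty Arrow table from the BigQuery schema so field metadata is preserved.
import pyarrow
from google.cloud.bigquery import _pandas_helpers, schema

bq_schema = [schema.SchemaField("geo", "GEOGRAPHY"), schema.SchemaField("dt", "DATETIME")]
arrow_fields = [_pandas_helpers.bq_to_arrow_field(field) for field in bq_schema]
empty_table = pyarrow.Table.from_batches([], schema=pyarrow.schema(arrow_fields))
print(empty_table.schema.field("dt").metadata)
# {b'ARROW:extension:name': b'google:sqlType:datetime'}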
17 changes: 17 additions & 0 deletions tests/system/conftest.py
@@ -13,6 +13,7 @@
# limitations under the License.

import pathlib
import re

import pytest
import test_utils.prefixer
@@ -61,6 +62,17 @@ def dataset_id(bigquery_client):
    bigquery_client.delete_dataset(dataset_id, delete_contents=True, not_found_ok=True)


@pytest.fixture()
def dataset_client(bigquery_client, dataset_id):
    import google.cloud.bigquery.job

    return bigquery.Client(
        default_query_job_config=google.cloud.bigquery.job.QueryJobConfig(
            default_dataset=f"{bigquery_client.project}.{dataset_id}",
        )
    )


@pytest.fixture
def table_id(dataset_id):
return f"{dataset_id}.table_{helpers.temp_suffix()}"
@@ -98,3 +110,8 @@ def scalars_extreme_table(
    job.result()
    yield full_table_id
    bigquery_client.delete_table(full_table_id)


@pytest.fixture
def test_table_name(request, replace_non_anum=re.compile(r"[^a-zA-Z0-9_]").sub):
    return replace_non_anum("_", request.node.name)
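The test_table_name fixture derives a table name from the pytest node name; parametrized node names contain brackets, so anything outside [a-zA-Z0-9_] is replaced with an underscore. A small illustration with a hypothetical node name:

# Illustration only; the node name below is hypothetical.
import re

replace_non_anum = re.compile(r"[^a-zA-Z0-9_]").sub
print(replace_non_anum("_", "test_arrow_extension_types_same_for_storage_and_REST_APIs_894[True]"))
# test_arrow_extension_types_same_for_storage_and_REST_APIs_894_True_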
53 changes: 53 additions & 0 deletions tests/system/test_arrow.py
@@ -110,3 +110,56 @@ def test_list_rows_nullable_scalars_dtypes(
    timestamp_type = schema.field("timestamp_col").type
    assert timestamp_type.unit == "us"
    assert timestamp_type.tz is not None


@pytest.mark.parametrize("do_insert", [True, False])
def test_arrow_extension_types_same_for_storage_and_REST_APIs_894(
dataset_client, test_table_name, do_insert
):
types = dict(
astring=("STRING", "'x'"),
astring9=("STRING(9)", "'x'"),
abytes=("BYTES", "b'x'"),
abytes9=("BYTES(9)", "b'x'"),
anumeric=("NUMERIC", "42"),
anumeric9=("NUMERIC(9)", "42"),
anumeric92=("NUMERIC(9,2)", "42"),
abignumeric=("BIGNUMERIC", "42e30"),
abignumeric49=("BIGNUMERIC(37)", "42e30"),
abignumeric492=("BIGNUMERIC(37,2)", "42e30"),
abool=("BOOL", "true"),
adate=("DATE", "'2021-09-06'"),
adatetime=("DATETIME", "'2021-09-06T09:57:26'"),
ageography=("GEOGRAPHY", "ST_GEOGFROMTEXT('point(0 0)')"),
# Can't get arrow data for interval :(
# ainterval=('INTERVAL', "make_interval(1, 2, 3, 4, 5, 6)"),
aint64=("INT64", "42"),
afloat64=("FLOAT64", "42.0"),
astruct=("STRUCT<v int64>", "struct(42)"),
atime=("TIME", "'1:2:3'"),
atimestamp=("TIMESTAMP", "'2021-09-06T09:57:26'"),
)
columns = ", ".join(f"{k} {t[0]}" for k, t in types.items())
dataset_client.query(f"create table {test_table_name} ({columns})").result()
if do_insert:
names = list(types)
values = ", ".join(types[name][1] for name in names)
names = ", ".join(names)
dataset_client.query(
f"insert into {test_table_name} ({names}) values ({values})"
).result()
at = dataset_client.query(f"select * from {test_table_name}").result().to_arrow()
smd = {at.field(i).name: at.field(i).metadata for i in range(at.num_columns)}
    at = (
        dataset_client.query(f"select * from {test_table_name}")
        .result()
        .to_arrow(create_bqstorage_client=False)
    )
    rmd = {at.field(i).name: at.field(i).metadata for i in range(at.num_columns)}

    assert rmd == smd
    assert rmd["adatetime"] == {b"ARROW:extension:name": b"google:sqlType:datetime"}
    assert rmd["ageography"] == {
        b"ARROW:extension:name": b"google:sqlType:geography",
        b"ARROW:extension:metadata": b'{"encoding": "WKT"}',
    }
23 changes: 23 additions & 0 deletions tests/unit/test__pandas_helpers.py
@@ -1696,3 +1696,26 @@ def test_bq_to_arrow_field_type_override(module_under_test):
        ).type
        == pyarrow.binary()
    )


@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`")
@pytest.mark.parametrize(
    "field_type, metadata",
    [
        ("datetime", {b"ARROW:extension:name": b"google:sqlType:datetime"}),
        (
            "geography",
            {
                b"ARROW:extension:name": b"google:sqlType:geography",
                b"ARROW:extension:metadata": b'{"encoding": "WKT"}',
            },
        ),
    ],
)
def test_bq_to_arrow_field_metadata(module_under_test, field_type, metadata):
    assert (
        module_under_test.bq_to_arrow_field(
            schema.SchemaField("g", field_type)
        ).metadata
        == metadata
    )
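A hedged sketch of a complementary check that could sit in the same module (it reuses the module's existing imports and the module_under_test fixture, and is not part of this PR): BigQuery types without an entry in BQ_FIELD_TYPE_TO_ARROW_FIELD_METADATA should produce a field with no metadata at all.

# Sketch of a possible follow-up test, not part of this PR.
@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`")
def test_bq_to_arrow_field_no_metadata_for_other_types(module_under_test):
    field = module_under_test.bq_to_arrow_field(schema.SchemaField("s", "STRING"))
    assert field.metadata is None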