diff --git a/google/cloud/bigquery/_helpers.py b/google/cloud/bigquery/_helpers.py index 28a76206e..d7189d322 100644 --- a/google/cloud/bigquery/_helpers.py +++ b/google/cloud/bigquery/_helpers.py @@ -107,6 +107,9 @@ def verify_version(self): class PyarrowVersions: """Version comparisons for pyarrow package.""" + # https://github.com/googleapis/python-bigquery/issues/781#issuecomment-883497414 + _PYARROW_BAD_VERSIONS = frozenset([packaging.version.Version("2.0.0")]) + def __init__(self): self._installed_version = None @@ -126,6 +129,14 @@ def installed_version(self) -> packaging.version.Version: return self._installed_version + @property + def is_bad_version(self) -> bool: + return self.installed_version in self._PYARROW_BAD_VERSIONS + + @property + def use_compliant_nested_type(self) -> bool: + return self.installed_version.major >= 4 + def try_import(self, raise_if_error: bool = False) -> Any: """Verify that a recent enough version of pyarrow extra is installed. diff --git a/google/cloud/bigquery/_pandas_helpers.py b/google/cloud/bigquery/_pandas_helpers.py index 869c0215d..0cb851469 100644 --- a/google/cloud/bigquery/_pandas_helpers.py +++ b/google/cloud/bigquery/_pandas_helpers.py @@ -79,8 +79,8 @@ def _to_wkb(v): _PANDAS_DTYPE_TO_BQ = { "bool": "BOOLEAN", "datetime64[ns, UTC]": "TIMESTAMP", - # BigQuery does not support uploading DATETIME values from Parquet files. - # See: https://github.com/googleapis/google-cloud-python/issues/9996 + # TODO: Update to DATETIME in V3 + # https://github.com/googleapis/python-bigquery/issues/985 "datetime64[ns]": "TIMESTAMP", "float32": "FLOAT", "float64": "FLOAT", @@ -396,7 +396,7 @@ def dataframe_to_bq_schema(dataframe, bq_schema): # column, but it was not found. if bq_schema_unused: raise ValueError( - u"bq_schema contains fields not present in dataframe: {}".format( + "bq_schema contains fields not present in dataframe: {}".format( bq_schema_unused ) ) @@ -405,7 +405,7 @@ def dataframe_to_bq_schema(dataframe, bq_schema): # pyarrow, if available. if unknown_type_fields: if not pyarrow: - msg = u"Could not determine the type of columns: {}".format( + msg = "Could not determine the type of columns: {}".format( ", ".join(field.name for field in unknown_type_fields) ) warnings.warn(msg) @@ -444,7 +444,14 @@ def augment_schema(dataframe, current_bq_schema): continue arrow_table = pyarrow.array(dataframe[field.name]) - detected_type = ARROW_SCALAR_IDS_TO_BQ.get(arrow_table.type.id) + + if pyarrow.types.is_list(arrow_table.type): + # `pyarrow.ListType` + detected_mode = "REPEATED" + detected_type = ARROW_SCALAR_IDS_TO_BQ.get(arrow_table.values.type.id) + else: + detected_mode = field.mode + detected_type = ARROW_SCALAR_IDS_TO_BQ.get(arrow_table.type.id) if detected_type is None: unknown_type_fields.append(field) @@ -453,7 +460,7 @@ def augment_schema(dataframe, current_bq_schema): new_field = schema.SchemaField( name=field.name, field_type=detected_type, - mode=field.mode, + mode=detected_mode, description=field.description, fields=field.fields, ) @@ -461,7 +468,7 @@ def augment_schema(dataframe, current_bq_schema): if unknown_type_fields: warnings.warn( - u"Pyarrow could not determine the type of columns: {}.".format( + "Pyarrow could not determine the type of columns: {}.".format( ", ".join(field.name for field in unknown_type_fields) ) ) @@ -500,7 +507,7 @@ def dataframe_to_arrow(dataframe, bq_schema): extra_fields = bq_field_names - column_and_index_names if extra_fields: raise ValueError( - u"bq_schema contains fields not present in dataframe: {}".format( + "bq_schema contains fields not present in dataframe: {}".format( extra_fields ) ) @@ -510,7 +517,7 @@ def dataframe_to_arrow(dataframe, bq_schema): missing_fields = column_names - bq_field_names if missing_fields: raise ValueError( - u"bq_schema is missing fields from dataframe: {}".format(missing_fields) + "bq_schema is missing fields from dataframe: {}".format(missing_fields) ) arrow_arrays = [] @@ -530,7 +537,13 @@ def dataframe_to_arrow(dataframe, bq_schema): return pyarrow.Table.from_arrays(arrow_arrays, names=arrow_names) -def dataframe_to_parquet(dataframe, bq_schema, filepath, parquet_compression="SNAPPY"): +def dataframe_to_parquet( + dataframe, + bq_schema, + filepath, + parquet_compression="SNAPPY", + parquet_use_compliant_nested_type=True, +): """Write dataframe as a Parquet file, according to the desired BQ schema. This function requires the :mod:`pyarrow` package. Arrow is used as an @@ -551,14 +564,29 @@ def dataframe_to_parquet(dataframe, bq_schema, filepath, parquet_compression="SN The compression codec to use by the the ``pyarrow.parquet.write_table`` serializing method. Defaults to "SNAPPY". https://arrow.apache.org/docs/python/generated/pyarrow.parquet.write_table.html#pyarrow-parquet-write-table + parquet_use_compliant_nested_type (bool): + Whether the ``pyarrow.parquet.write_table`` serializing method should write + compliant Parquet nested type (lists). Defaults to ``True``. + https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#nested-types + https://arrow.apache.org/docs/python/generated/pyarrow.parquet.write_table.html#pyarrow-parquet-write-table + + This argument is ignored for ``pyarrow`` versions earlier than ``4.0.0``. """ pyarrow = _helpers.PYARROW_VERSIONS.try_import(raise_if_error=True) import pyarrow.parquet + kwargs = ( + {"use_compliant_nested_type": parquet_use_compliant_nested_type} + if _helpers.PYARROW_VERSIONS.use_compliant_nested_type + else {} + ) + bq_schema = schema._to_schema_fields(bq_schema) arrow_table = dataframe_to_arrow(dataframe, bq_schema) - pyarrow.parquet.write_table(arrow_table, filepath, compression=parquet_compression) + pyarrow.parquet.write_table( + arrow_table, filepath, compression=parquet_compression, **kwargs, + ) def _row_iterator_page_to_arrow(page, column_names, arrow_types): diff --git a/google/cloud/bigquery/client.py b/google/cloud/bigquery/client.py index 47ff83c5d..a8a1c1e16 100644 --- a/google/cloud/bigquery/client.py +++ b/google/cloud/bigquery/client.py @@ -27,19 +27,11 @@ import json import math import os -import packaging.version import tempfile from typing import Any, BinaryIO, Dict, Iterable, Optional, Sequence, Tuple, Union import uuid import warnings -try: - import pyarrow - - _PYARROW_VERSION = packaging.version.parse(pyarrow.__version__) -except ImportError: # pragma: NO COVER - pyarrow = None - from google import resumable_media # type: ignore from google.resumable_media.requests import MultipartUpload from google.resumable_media.requests import ResumableUpload @@ -103,6 +95,10 @@ from google.cloud.bigquery.table import TableListItem from google.cloud.bigquery.table import TableReference from google.cloud.bigquery.table import RowIterator +from google.cloud.bigquery.format_options import ParquetOptions +from google.cloud.bigquery import _helpers + +pyarrow = _helpers.PYARROW_VERSIONS.try_import() _DEFAULT_CHUNKSIZE = 100 * 1024 * 1024 # 100 MB @@ -128,8 +124,6 @@ # https://github.com/googleapis/python-bigquery/issues/438 _MIN_GET_QUERY_RESULTS_TIMEOUT = 120 -# https://github.com/googleapis/python-bigquery/issues/781#issuecomment-883497414 -_PYARROW_BAD_VERSIONS = frozenset([packaging.version.Version("2.0.0")]) TIMEOUT_HEADER = "X-Server-Timeout" @@ -2469,10 +2463,10 @@ def load_table_from_dataframe( They are supported when using the PARQUET source format, but due to the way they are encoded in the ``parquet`` file, a mismatch with the existing table schema can occur, so - 100% compatibility cannot be guaranteed for REPEATED fields when + REPEATED fields are not properly supported when using ``pyarrow<4.0.0`` using the parquet format. - https://github.com/googleapis/python-bigquery/issues/17 + https://github.com/googleapis/python-bigquery/issues/19 Args: dataframe (pandas.DataFrame): @@ -2519,18 +2513,18 @@ def load_table_from_dataframe( :attr:`~google.cloud.bigquery.job.SourceFormat.PARQUET` are supported. parquet_compression (Optional[str]): - [Beta] The compression method to use if intermittently - serializing ``dataframe`` to a parquet file. - - The argument is directly passed as the ``compression`` - argument to the underlying ``pyarrow.parquet.write_table()`` - method (the default value "snappy" gets converted to uppercase). - https://arrow.apache.org/docs/python/generated/pyarrow.parquet.write_table.html#pyarrow-parquet-write-table - - If the job config schema is missing, the argument is directly - passed as the ``compression`` argument to the underlying - ``DataFrame.to_parquet()`` method. - https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_parquet.html#pandas.DataFrame.to_parquet + [Beta] The compression method to use if intermittently + serializing ``dataframe`` to a parquet file. + + The argument is directly passed as the ``compression`` + argument to the underlying ``pyarrow.parquet.write_table()`` + method (the default value "snappy" gets converted to uppercase). + https://arrow.apache.org/docs/python/generated/pyarrow.parquet.write_table.html#pyarrow-parquet-write-table + + If the job config schema is missing, the argument is directly + passed as the ``compression`` argument to the underlying + ``DataFrame.to_parquet()`` method. + https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_parquet.html#pandas.DataFrame.to_parquet timeout (Optional[float]): The number of seconds to wait for the underlying HTTP transport before using ``retry``. @@ -2562,6 +2556,16 @@ def load_table_from_dataframe( if job_config.source_format is None: # default value job_config.source_format = job.SourceFormat.PARQUET + + if ( + job_config.source_format == job.SourceFormat.PARQUET + and job_config.parquet_options is None + ): + parquet_options = ParquetOptions() + # default value + parquet_options.enable_list_inference = True + job_config.parquet_options = parquet_options + if job_config.source_format not in supported_formats: raise ValueError( "Got unexpected source_format: '{}'. Currently, only PARQUET and CSV are supported".format( @@ -2628,12 +2632,12 @@ def load_table_from_dataframe( try: if job_config.source_format == job.SourceFormat.PARQUET: - if _PYARROW_VERSION in _PYARROW_BAD_VERSIONS: + if _helpers.PYARROW_VERSIONS.is_bad_version: msg = ( "Loading dataframe data in PARQUET format with pyarrow " - f"{_PYARROW_VERSION} can result in data corruption. It is " - "therefore *strongly* advised to use a different pyarrow " - "version or a different source format. " + f"{_helpers.PYARROW_VERSIONS.installed_version} can result in data " + "corruption. It is therefore *strongly* advised to use a " + "different pyarrow version or a different source format. " "See: https://github.com/googleapis/python-bigquery/issues/781" ) warnings.warn(msg, category=RuntimeWarning) @@ -2647,9 +2651,19 @@ def load_table_from_dataframe( job_config.schema, tmppath, parquet_compression=parquet_compression, + parquet_use_compliant_nested_type=True, ) else: - dataframe.to_parquet(tmppath, compression=parquet_compression) + dataframe.to_parquet( + tmppath, + engine="pyarrow", + compression=parquet_compression, + **( + {"use_compliant_nested_type": True} + if _helpers.PYARROW_VERSIONS.use_compliant_nested_type + else {} + ), + ) else: diff --git a/tests/system/test_pandas.py b/tests/system/test_pandas.py index 93ce23481..1f43a369a 100644 --- a/tests/system/test_pandas.py +++ b/tests/system/test_pandas.py @@ -24,6 +24,7 @@ import google.api_core.retry import pkg_resources import pytest +import numpy from google.cloud import bigquery from . import helpers @@ -84,6 +85,81 @@ def test_load_table_from_dataframe_w_automatic_schema(bigquery_client, dataset_i ("uint8_col", pandas.Series([0, 1, 2], dtype="uint8")), ("uint16_col", pandas.Series([3, 4, 5], dtype="uint16")), ("uint32_col", pandas.Series([6, 7, 8], dtype="uint32")), + ("array_bool_col", pandas.Series([[True], [False], [True]])), + ( + "array_ts_col", + pandas.Series( + [ + [ + datetime.datetime( + 2010, 1, 2, 3, 44, 50, tzinfo=datetime.timezone.utc + ), + ], + [ + datetime.datetime( + 2011, 2, 3, 14, 50, 59, tzinfo=datetime.timezone.utc + ), + ], + [ + datetime.datetime( + 2012, 3, 14, 15, 16, tzinfo=datetime.timezone.utc + ), + ], + ], + ), + ), + ( + "array_dt_col", + pandas.Series( + [ + [datetime.datetime(2010, 1, 2, 3, 44, 50)], + [datetime.datetime(2011, 2, 3, 14, 50, 59)], + [datetime.datetime(2012, 3, 14, 15, 16)], + ], + ), + ), + ( + "array_float32_col", + pandas.Series( + [numpy.array([_], dtype="float32") for _ in [1.0, 2.0, 3.0]] + ), + ), + ( + "array_float64_col", + pandas.Series( + [numpy.array([_], dtype="float64") for _ in [4.0, 5.0, 6.0]] + ), + ), + ( + "array_int8_col", + pandas.Series( + [numpy.array([_], dtype="int8") for _ in [-12, -11, -10]] + ), + ), + ( + "array_int16_col", + pandas.Series([numpy.array([_], dtype="int16") for _ in [-9, -8, -7]]), + ), + ( + "array_int32_col", + pandas.Series([numpy.array([_], dtype="int32") for _ in [-6, -5, -4]]), + ), + ( + "array_int64_col", + pandas.Series([numpy.array([_], dtype="int64") for _ in [-3, -2, -1]]), + ), + ( + "array_uint8_col", + pandas.Series([numpy.array([_], dtype="uint8") for _ in [0, 1, 2]]), + ), + ( + "array_uint16_col", + pandas.Series([numpy.array([_], dtype="uint16") for _ in [3, 4, 5]]), + ), + ( + "array_uint32_col", + pandas.Series([numpy.array([_], dtype="uint32") for _ in [6, 7, 8]]), + ), ] ) dataframe = pandas.DataFrame(df_data, columns=df_data.keys()) @@ -99,9 +175,8 @@ def test_load_table_from_dataframe_w_automatic_schema(bigquery_client, dataset_i assert tuple(table.schema) == ( bigquery.SchemaField("bool_col", "BOOLEAN"), bigquery.SchemaField("ts_col", "TIMESTAMP"), - # BigQuery does not support uploading DATETIME values from - # Parquet files. See: - # https://github.com/googleapis/google-cloud-python/issues/9996 + # TODO: Update to DATETIME in V3 + # https://github.com/googleapis/python-bigquery/issues/985 bigquery.SchemaField("dt_col", "TIMESTAMP"), bigquery.SchemaField("float32_col", "FLOAT"), bigquery.SchemaField("float64_col", "FLOAT"), @@ -112,6 +187,20 @@ def test_load_table_from_dataframe_w_automatic_schema(bigquery_client, dataset_i bigquery.SchemaField("uint8_col", "INTEGER"), bigquery.SchemaField("uint16_col", "INTEGER"), bigquery.SchemaField("uint32_col", "INTEGER"), + bigquery.SchemaField("array_bool_col", "BOOLEAN", mode="REPEATED"), + bigquery.SchemaField("array_ts_col", "TIMESTAMP", mode="REPEATED"), + # TODO: Update to DATETIME in V3 + # https://github.com/googleapis/python-bigquery/issues/985 + bigquery.SchemaField("array_dt_col", "TIMESTAMP", mode="REPEATED"), + bigquery.SchemaField("array_float32_col", "FLOAT", mode="REPEATED"), + bigquery.SchemaField("array_float64_col", "FLOAT", mode="REPEATED"), + bigquery.SchemaField("array_int8_col", "INTEGER", mode="REPEATED"), + bigquery.SchemaField("array_int16_col", "INTEGER", mode="REPEATED"), + bigquery.SchemaField("array_int32_col", "INTEGER", mode="REPEATED"), + bigquery.SchemaField("array_int64_col", "INTEGER", mode="REPEATED"), + bigquery.SchemaField("array_uint8_col", "INTEGER", mode="REPEATED"), + bigquery.SchemaField("array_uint16_col", "INTEGER", mode="REPEATED"), + bigquery.SchemaField("array_uint32_col", "INTEGER", mode="REPEATED"), ) assert table.num_rows == 3 diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index eb70470b5..48dacf7e2 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -57,6 +57,7 @@ from google.cloud import bigquery_v2 from google.cloud.bigquery.dataset import DatasetReference from google.cloud.bigquery.retry import DEFAULT_TIMEOUT +from google.cloud.bigquery import ParquetOptions try: from google.cloud import bigquery_storage @@ -6942,6 +6943,179 @@ def test_load_table_from_dataframe_w_custom_job_config_w_source_format(self): # the original config object should not have been modified assert job_config.to_api_repr() == original_config_copy.to_api_repr() + @unittest.skipIf(pandas is None, "Requires `pandas`") + @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") + def test_load_table_from_dataframe_w_parquet_options_none(self): + from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES + from google.cloud.bigquery import job + from google.cloud.bigquery.schema import SchemaField + + client = self._make_client() + records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}] + dataframe = pandas.DataFrame(records) + + job_config = job.LoadJobConfig( + write_disposition=job.WriteDisposition.WRITE_TRUNCATE, + source_format=job.SourceFormat.PARQUET, + ) + + get_table_patch = mock.patch( + "google.cloud.bigquery.client.Client.get_table", + autospec=True, + return_value=mock.Mock( + schema=[SchemaField("id", "INTEGER"), SchemaField("age", "INTEGER")] + ), + ) + load_patch = mock.patch( + "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True + ) + with load_patch as load_table_from_file, get_table_patch as get_table: + client.load_table_from_dataframe( + dataframe, self.TABLE_REF, job_config=job_config, location=self.LOCATION + ) + + # no need to fetch and inspect table schema for WRITE_TRUNCATE jobs + assert not get_table.called + + load_table_from_file.assert_called_once_with( + client, + mock.ANY, + self.TABLE_REF, + num_retries=_DEFAULT_NUM_RETRIES, + rewind=True, + size=mock.ANY, + job_id=mock.ANY, + job_id_prefix=None, + location=self.LOCATION, + project=None, + job_config=mock.ANY, + timeout=DEFAULT_TIMEOUT, + ) + + sent_config = load_table_from_file.mock_calls[0][2]["job_config"] + assert sent_config.parquet_options.enable_list_inference is True + + @unittest.skipIf(pandas is None, "Requires `pandas`") + @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") + def test_load_table_from_dataframe_w_list_inference_none(self): + from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES + from google.cloud.bigquery import job + from google.cloud.bigquery.schema import SchemaField + + client = self._make_client() + records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}] + dataframe = pandas.DataFrame(records) + + parquet_options = ParquetOptions() + + job_config = job.LoadJobConfig( + write_disposition=job.WriteDisposition.WRITE_TRUNCATE, + source_format=job.SourceFormat.PARQUET, + ) + job_config.parquet_options = parquet_options + + original_config_copy = copy.deepcopy(job_config) + + get_table_patch = mock.patch( + "google.cloud.bigquery.client.Client.get_table", + autospec=True, + return_value=mock.Mock( + schema=[SchemaField("id", "INTEGER"), SchemaField("age", "INTEGER")] + ), + ) + load_patch = mock.patch( + "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True + ) + with load_patch as load_table_from_file, get_table_patch as get_table: + client.load_table_from_dataframe( + dataframe, self.TABLE_REF, job_config=job_config, location=self.LOCATION + ) + + # no need to fetch and inspect table schema for WRITE_TRUNCATE jobs + assert not get_table.called + + load_table_from_file.assert_called_once_with( + client, + mock.ANY, + self.TABLE_REF, + num_retries=_DEFAULT_NUM_RETRIES, + rewind=True, + size=mock.ANY, + job_id=mock.ANY, + job_id_prefix=None, + location=self.LOCATION, + project=None, + job_config=mock.ANY, + timeout=DEFAULT_TIMEOUT, + ) + + sent_config = load_table_from_file.mock_calls[0][2]["job_config"] + assert sent_config.parquet_options.enable_list_inference is None + + # the original config object should not have been modified + assert job_config.to_api_repr() == original_config_copy.to_api_repr() + + @unittest.skipIf(pandas is None, "Requires `pandas`") + @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") + def test_load_table_from_dataframe_w_list_inference_false(self): + from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES + from google.cloud.bigquery import job + from google.cloud.bigquery.schema import SchemaField + + client = self._make_client() + records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}] + dataframe = pandas.DataFrame(records) + + parquet_options = ParquetOptions() + parquet_options.enable_list_inference = False + + job_config = job.LoadJobConfig( + write_disposition=job.WriteDisposition.WRITE_TRUNCATE, + source_format=job.SourceFormat.PARQUET, + ) + job_config.parquet_options = parquet_options + + original_config_copy = copy.deepcopy(job_config) + + get_table_patch = mock.patch( + "google.cloud.bigquery.client.Client.get_table", + autospec=True, + return_value=mock.Mock( + schema=[SchemaField("id", "INTEGER"), SchemaField("age", "INTEGER")] + ), + ) + load_patch = mock.patch( + "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True + ) + with load_patch as load_table_from_file, get_table_patch as get_table: + client.load_table_from_dataframe( + dataframe, self.TABLE_REF, job_config=job_config, location=self.LOCATION + ) + + # no need to fetch and inspect table schema for WRITE_TRUNCATE jobs + assert not get_table.called + + load_table_from_file.assert_called_once_with( + client, + mock.ANY, + self.TABLE_REF, + num_retries=_DEFAULT_NUM_RETRIES, + rewind=True, + size=mock.ANY, + job_id=mock.ANY, + job_id_prefix=None, + location=self.LOCATION, + project=None, + job_config=mock.ANY, + timeout=DEFAULT_TIMEOUT, + ) + + sent_config = load_table_from_file.mock_calls[0][2]["job_config"] + assert sent_config.parquet_options.enable_list_inference is False + + # the original config object should not have been modified + assert job_config.to_api_repr() == original_config_copy.to_api_repr() + @unittest.skipIf(pandas is None, "Requires `pandas`") @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_load_table_from_dataframe_w_custom_job_config_w_wrong_source_format(self): @@ -7293,6 +7467,124 @@ def test_load_table_from_dataframe_struct_fields(self): assert sent_config.source_format == job.SourceFormat.PARQUET assert sent_config.schema == schema + @unittest.skipIf(pandas is None, "Requires `pandas`") + @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") + def test_load_table_from_dataframe_array_fields(self): + """Test that a DataFrame with array columns can be uploaded correctly. + + See: https://github.com/googleapis/python-bigquery/issues/19 + """ + from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES + from google.cloud.bigquery import job + from google.cloud.bigquery.schema import SchemaField + + client = self._make_client() + + records = [(3.14, [1, 2])] + dataframe = pandas.DataFrame( + data=records, columns=["float_column", "array_column"] + ) + + schema = [ + SchemaField("float_column", "FLOAT"), + SchemaField("array_column", "INTEGER", mode="REPEATED",), + ] + job_config = job.LoadJobConfig(schema=schema) + + load_patch = mock.patch( + "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True + ) + + get_table_patch = mock.patch( + "google.cloud.bigquery.client.Client.get_table", + autospec=True, + side_effect=google.api_core.exceptions.NotFound("Table not found"), + ) + + with load_patch as load_table_from_file, get_table_patch: + client.load_table_from_dataframe( + dataframe, + self.TABLE_REF, + job_config=job_config, + location=self.LOCATION, + ) + + load_table_from_file.assert_called_once_with( + client, + mock.ANY, + self.TABLE_REF, + num_retries=_DEFAULT_NUM_RETRIES, + rewind=True, + size=mock.ANY, + job_id=mock.ANY, + job_id_prefix=None, + location=self.LOCATION, + project=None, + job_config=mock.ANY, + timeout=DEFAULT_TIMEOUT, + ) + + sent_config = load_table_from_file.mock_calls[0][2]["job_config"] + assert sent_config.source_format == job.SourceFormat.PARQUET + assert sent_config.schema == schema + + @unittest.skipIf(pandas is None, "Requires `pandas`") + @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") + def test_load_table_from_dataframe_array_fields_w_auto_schema(self): + """Test that a DataFrame with array columns can be uploaded correctly. + + See: https://github.com/googleapis/python-bigquery/issues/19 + """ + from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES + from google.cloud.bigquery import job + from google.cloud.bigquery.schema import SchemaField + + client = self._make_client() + + records = [(3.14, [1, 2])] + dataframe = pandas.DataFrame( + data=records, columns=["float_column", "array_column"] + ) + + expected_schema = [ + SchemaField("float_column", "FLOAT"), + SchemaField("array_column", "INT64", mode="REPEATED",), + ] + + load_patch = mock.patch( + "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True + ) + + get_table_patch = mock.patch( + "google.cloud.bigquery.client.Client.get_table", + autospec=True, + side_effect=google.api_core.exceptions.NotFound("Table not found"), + ) + + with load_patch as load_table_from_file, get_table_patch: + client.load_table_from_dataframe( + dataframe, self.TABLE_REF, location=self.LOCATION, + ) + + load_table_from_file.assert_called_once_with( + client, + mock.ANY, + self.TABLE_REF, + num_retries=_DEFAULT_NUM_RETRIES, + rewind=True, + size=mock.ANY, + job_id=mock.ANY, + job_id_prefix=None, + location=self.LOCATION, + project=None, + job_config=mock.ANY, + timeout=DEFAULT_TIMEOUT, + ) + + sent_config = load_table_from_file.mock_calls[0][2]["job_config"] + assert sent_config.source_format == job.SourceFormat.PARQUET + assert sent_config.schema == expected_schema + @unittest.skipIf(pandas is None, "Requires `pandas`") @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_load_table_from_dataframe_w_partial_schema(self): @@ -7540,9 +7832,13 @@ def test_load_table_from_dataframe_w_bad_pyarrow_issues_warning(self): records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}] dataframe = pandas.DataFrame(records) + _helpers_mock = mock.MagicMock() + _helpers_mock.PYARROW_VERSIONS = mock.MagicMock() + _helpers_mock.PYARROW_VERSIONS.installed_version = packaging.version.parse( + "2.0.0" + ) # A known bad version of pyarrow. pyarrow_version_patch = mock.patch( - "google.cloud.bigquery.client._PYARROW_VERSION", - packaging.version.parse("2.0.0"), # A known bad version of pyarrow. + "google.cloud.bigquery.client._helpers", _helpers_mock ) get_table_patch = mock.patch( "google.cloud.bigquery.client.Client.get_table",