diff --git a/google/cloud/bigquery/_pandas_helpers.py b/google/cloud/bigquery/_pandas_helpers.py index 57c8f95f6..7774ce26b 100644 --- a/google/cloud/bigquery/_pandas_helpers.py +++ b/google/cloud/bigquery/_pandas_helpers.py @@ -474,7 +474,7 @@ def dataframe_to_parquet(dataframe, bq_schema, filepath, parquet_compression="SN pyarrow.parquet.write_table(arrow_table, filepath, compression=parquet_compression) -def _tabledata_list_page_to_arrow(page, column_names, arrow_types): +def _row_iterator_page_to_arrow(page, column_names, arrow_types): # Iterate over the page to force the API request to get the page data. try: next(iter(page)) @@ -490,8 +490,8 @@ def _tabledata_list_page_to_arrow(page, column_names, arrow_types): return pyarrow.RecordBatch.from_arrays(arrays, names=column_names) -def download_arrow_tabledata_list(pages, bq_schema): - """Use tabledata.list to construct an iterable of RecordBatches. +def download_arrow_row_iterator(pages, bq_schema): + """Use HTTP JSON RowIterator to construct an iterable of RecordBatches. Args: pages (Iterator[:class:`google.api_core.page_iterator.Page`]): @@ -510,10 +510,10 @@ def download_arrow_tabledata_list(pages, bq_schema): arrow_types = [bq_to_arrow_data_type(field) for field in bq_schema] for page in pages: - yield _tabledata_list_page_to_arrow(page, column_names, arrow_types) + yield _row_iterator_page_to_arrow(page, column_names, arrow_types) -def _tabledata_list_page_to_dataframe(page, column_names, dtypes): +def _row_iterator_page_to_dataframe(page, column_names, dtypes): # Iterate over the page to force the API request to get the page data. try: next(iter(page)) @@ -528,8 +528,8 @@ def _tabledata_list_page_to_dataframe(page, column_names, dtypes): return pandas.DataFrame(columns, columns=column_names) -def download_dataframe_tabledata_list(pages, bq_schema, dtypes): - """Use (slower, but free) tabledata.list to construct a DataFrame. +def download_dataframe_row_iterator(pages, bq_schema, dtypes): + """Use HTTP JSON RowIterator to construct a DataFrame. Args: pages (Iterator[:class:`google.api_core.page_iterator.Page`]): @@ -549,7 +549,7 @@ def download_dataframe_tabledata_list(pages, bq_schema, dtypes): bq_schema = schema._to_schema_fields(bq_schema) column_names = [field.name for field in bq_schema] for page in pages: - yield _tabledata_list_page_to_dataframe(page, column_names, dtypes) + yield _row_iterator_page_to_dataframe(page, column_names, dtypes) def _bqstorage_page_to_arrow(page): diff --git a/google/cloud/bigquery/client.py b/google/cloud/bigquery/client.py index 57df9455e..cd1474336 100644 --- a/google/cloud/bigquery/client.py +++ b/google/cloud/bigquery/client.py @@ -80,18 +80,19 @@ _MAX_MULTIPART_SIZE = 5 * 1024 * 1024 _DEFAULT_NUM_RETRIES = 6 _BASE_UPLOAD_TEMPLATE = ( - u"https://bigquery.googleapis.com/upload/bigquery/v2/projects/" - u"{project}/jobs?uploadType=" + "https://bigquery.googleapis.com/upload/bigquery/v2/projects/" + "{project}/jobs?uploadType=" ) -_MULTIPART_URL_TEMPLATE = _BASE_UPLOAD_TEMPLATE + u"multipart" -_RESUMABLE_URL_TEMPLATE = _BASE_UPLOAD_TEMPLATE + u"resumable" -_GENERIC_CONTENT_TYPE = u"*/*" +_MULTIPART_URL_TEMPLATE = _BASE_UPLOAD_TEMPLATE + "multipart" +_RESUMABLE_URL_TEMPLATE = _BASE_UPLOAD_TEMPLATE + "resumable" +_GENERIC_CONTENT_TYPE = "*/*" _READ_LESS_THAN_SIZE = ( "Size {:d} was specified but the file-like object only had " "{:d} bytes remaining." 
) _NEED_TABLE_ARGUMENT = ( "The table argument should be a table ID string, Table, or TableReference" ) +_LIST_ROWS_FROM_QUERY_RESULTS_FIELDS = "jobReference,totalRows,pageToken,rows" class Project(object): @@ -293,7 +294,7 @@ def api_request(*args, **kwargs): span_attributes=span_attributes, *args, timeout=timeout, - **kwargs + **kwargs, ) return page_iterator.HTTPIterator( @@ -371,7 +372,7 @@ def api_request(*args, **kwargs): span_attributes=span_attributes, *args, timeout=timeout, - **kwargs + **kwargs, ) return page_iterator.HTTPIterator( @@ -1129,7 +1130,7 @@ def api_request(*args, **kwargs): span_attributes=span_attributes, *args, timeout=timeout, - **kwargs + **kwargs, ) result = page_iterator.HTTPIterator( @@ -1207,7 +1208,7 @@ def api_request(*args, **kwargs): span_attributes=span_attributes, *args, timeout=timeout, - **kwargs + **kwargs, ) result = page_iterator.HTTPIterator( @@ -1284,7 +1285,7 @@ def api_request(*args, **kwargs): span_attributes=span_attributes, *args, timeout=timeout, - **kwargs + **kwargs, ) result = page_iterator.HTTPIterator( @@ -1510,7 +1511,7 @@ def delete_table( raise def _get_query_results( - self, job_id, retry, project=None, timeout_ms=None, location=None, timeout=None + self, job_id, retry, project=None, timeout_ms=None, location=None, timeout=None, ): """Get the query results object for a query job. @@ -1890,7 +1891,7 @@ def api_request(*args, **kwargs): span_attributes=span_attributes, *args, timeout=timeout, - **kwargs + **kwargs, ) return page_iterator.HTTPIterator( @@ -2374,7 +2375,7 @@ def load_table_from_json( destination = _table_arg_to_table_ref(destination, default_project=self.project) - data_str = u"\n".join(json.dumps(item) for item in json_rows) + data_str = "\n".join(json.dumps(item) for item in json_rows) encoded_str = data_str.encode() data_file = io.BytesIO(encoded_str) return self.load_table_from_file( @@ -3169,6 +3170,83 @@ def list_rows( # Pass in selected_fields separately from schema so that full # tables can be fetched without a column filter. selected_fields=selected_fields, + total_rows=getattr(table, "num_rows", None), + ) + return row_iterator + + def _list_rows_from_query_results( + self, + job_id, + location, + project, + schema, + total_rows=None, + destination=None, + max_results=None, + start_index=None, + page_size=None, + retry=DEFAULT_RETRY, + timeout=None, + ): + """List the rows of a completed query. + See + https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/getQueryResults + Args: + job_id (str): + ID of a query job. + location (str): Location of the query job. + project (str): + ID of the project where the query job was run. + schema (Sequence[google.cloud.bigquery.schema.SchemaField]): + The fields expected in these query results. Used to convert + from JSON to expected Python types. + total_rows (Optional[int]): + Total number of rows in the query results. + destination (Optional[Union[ \ + google.cloud.bigquery.table.Table, \ + google.cloud.bigquery.table.TableListItem, \ + google.cloud.bigquery.table.TableReference, \ + str, \ + ]]): + Destination table reference. Used to fetch the query results + with the BigQuery Storage API. + max_results (Optional[int]): + Maximum number of rows to return across the whole iterator. + start_index (Optional[int]): + The zero-based index of the starting row to read. + page_size (Optional[int]): + The maximum number of rows in each page of results from this request. + Non-positive values are ignored. Defaults to a sensible value set by the API. 
+ retry (Optional[google.api_core.retry.Retry]): + How to retry the RPC. + timeout (Optional[float]): + The number of seconds to wait for the underlying HTTP transport + before using ``retry``. + If multiple requests are made under the hood, ``timeout`` + applies to each individual request. + Returns: + google.cloud.bigquery.table.RowIterator: + Iterator of row data + :class:`~google.cloud.bigquery.table.Row`-s. + """ + params = { + "fields": _LIST_ROWS_FROM_QUERY_RESULTS_FIELDS, + "location": location, + } + + if start_index is not None: + params["startIndex"] = start_index + + row_iterator = RowIterator( + client=self, + api_request=functools.partial(self._call_api, retry, timeout=timeout), + path=f"/projects/{project}/queries/{job_id}", + schema=schema, + max_results=max_results, + page_size=page_size, + table=destination, + extra_params=params, + total_rows=total_rows, ) return row_iterator diff --git a/google/cloud/bigquery/job/query.py b/google/cloud/bigquery/job/query.py index e25077360..1e2002eab 100644 --- a/google/cloud/bigquery/job/query.py +++ b/google/cloud/bigquery/job/query.py @@ -38,7 +38,6 @@ from google.cloud.bigquery.table import _EmptyRowIterator from google.cloud.bigquery.table import RangePartitioning from google.cloud.bigquery.table import _table_arg_to_table_ref -from google.cloud.bigquery.table import Table from google.cloud.bigquery.table import TableReference from google.cloud.bigquery.table import TimePartitioning @@ -1159,12 +1158,13 @@ def result( if self._query_results.total_rows is None: return _EmptyRowIterator() - schema = self._query_results.schema - dest_table_ref = self.destination - dest_table = Table(dest_table_ref, schema=schema) - dest_table._properties["numRows"] = self._query_results.total_rows - rows = self._client.list_rows( - dest_table, + rows = self._client._list_rows_from_query_results( + self._query_results.job_id, + self.location, + self._query_results.project, + self._query_results.schema, + total_rows=self._query_results.total_rows, + destination=self.destination, page_size=page_size, max_results=max_results, start_index=start_index, diff --git a/google/cloud/bigquery/table.py b/google/cloud/bigquery/table.py index d6d966eee..e46b7e3cd 100644 --- a/google/cloud/bigquery/table.py +++ b/google/cloud/bigquery/table.py @@ -1306,6 +1306,8 @@ class RowIterator(HTTPIterator): call the BigQuery Storage API to fetch rows. selected_fields (Optional[Sequence[google.cloud.bigquery.schema.SchemaField]]): A subset of columns to select from this table. + total_rows (Optional[int]): + Total number of rows in the table. """ @@ -1321,6 +1323,7 @@ def __init__( extra_params=None, table=None, selected_fields=None, + total_rows=None, ): super(RowIterator, self).__init__( client, @@ -1342,7 +1345,7 @@ def __init__( self._schema = schema self._selected_fields = selected_fields self._table = table - self._total_rows = getattr(table, "num_rows", None) + self._total_rows = total_rows def _get_next_page_response(self): """Requests the next page from the path provided. 
@@ -1419,7 +1422,7 @@ def _to_arrow_iterable(self, bqstorage_client=None): selected_fields=self._selected_fields, ) tabledata_list_download = functools.partial( - _pandas_helpers.download_arrow_tabledata_list, iter(self.pages), self.schema + _pandas_helpers.download_arrow_row_iterator, iter(self.pages), self.schema ) return self._to_page_iterable( bqstorage_download, @@ -1496,7 +1499,7 @@ def to_arrow( ) and self.max_results is not None: warnings.warn( "Cannot use bqstorage_client if max_results is set, " - "reverting to fetching data with the tabledata.list endpoint.", + "reverting to fetching data with the REST endpoint.", stacklevel=2, ) create_bqstorage_client = False @@ -1582,7 +1585,7 @@ def to_dataframe_iterable(self, bqstorage_client=None, dtypes=None): selected_fields=self._selected_fields, ) tabledata_list_download = functools.partial( - _pandas_helpers.download_dataframe_tabledata_list, + _pandas_helpers.download_dataframe_row_iterator, iter(self.pages), self.schema, dtypes, @@ -1680,7 +1683,7 @@ def to_dataframe( ) and self.max_results is not None: warnings.warn( "Cannot use bqstorage_client if max_results is set, " - "reverting to fetching data with the tabledata.list endpoint.", + "reverting to fetching data with the REST endpoint.", stacklevel=2, ) create_bqstorage_client = False @@ -2167,7 +2170,7 @@ def _item_to_row(iterator, resource): ) -def _tabledata_list_page_columns(schema, response): +def _row_iterator_page_columns(schema, response): """Make a generator of all the columns in a page from tabledata.list. This enables creating a :class:`pandas.DataFrame` and other @@ -2197,7 +2200,7 @@ def _rows_page_start(iterator, page, response): """ # Make a (lazy) copy of the page in column-oriented format for use in data # science packages. 
- page._columns = _tabledata_list_page_columns(iterator._schema, response) + page._columns = _row_iterator_page_columns(iterator._schema, response) total_rows = response.get("totalRows") if total_rows is not None: diff --git a/tests/unit/job/helpers.py b/tests/unit/job/helpers.py index f928054f6..ea071c5ac 100644 --- a/tests/unit/job/helpers.py +++ b/tests/unit/job/helpers.py @@ -60,6 +60,7 @@ def _make_job_resource( endpoint="https://bigquery.googleapis.com", job_type="load", job_id="a-random-id", + location="US", project_id="some-project", user_email="bq-user@example.com", ): @@ -69,7 +70,11 @@ def _make_job_resource( "statistics": {"creationTime": creation_time_ms, job_type: {}}, "etag": etag, "id": "{}:{}".format(project_id, job_id), - "jobReference": {"projectId": project_id, "jobId": job_id}, + "jobReference": { + "projectId": project_id, + "jobId": job_id, + "location": location, + }, "selfLink": "{}/bigquery/v2/projects/{}/jobs/{}".format( endpoint, project_id, job_id ), @@ -130,7 +135,7 @@ def _table_ref(self, table_id): return TableReference(self.DS_REF, table_id) - def _make_resource(self, started=False, ended=False): + def _make_resource(self, started=False, ended=False, location="US"): self._setUpConstants() return _make_job_resource( creation_time_ms=int(self.WHEN_TS * 1000), @@ -144,6 +149,7 @@ def _make_resource(self, started=False, ended=False): job_id=self.JOB_ID, project_id=self.PROJECT, user_email=self.USER_EMAIL, + location=location, ) def _verifyInitialReadonlyProperties(self, job): diff --git a/tests/unit/job/test_base.py b/tests/unit/job/test_base.py index 90d4388b8..12e2d4b8b 100644 --- a/tests/unit/job/test_base.py +++ b/tests/unit/job/test_base.py @@ -882,10 +882,14 @@ def test_done_already(self): def test_result_default_wo_state(self): begun_job_resource = _make_job_resource( - job_id=self.JOB_ID, project_id=self.PROJECT, started=True + job_id=self.JOB_ID, project_id=self.PROJECT, location="US", started=True ) done_job_resource = _make_job_resource( - job_id=self.JOB_ID, project_id=self.PROJECT, started=True, ended=True + job_id=self.JOB_ID, + project_id=self.PROJECT, + location="US", + started=True, + ended=True, ) conn = _make_connection( _make_retriable_exception(), @@ -907,7 +911,7 @@ def test_result_default_wo_state(self): reload_call = mock.call( method="GET", path=f"/projects/{self.PROJECT}/jobs/{self.JOB_ID}", - query_params={}, + query_params={"location": "US"}, timeout=None, ) conn.api_request.assert_has_calls( @@ -916,38 +920,48 @@ def test_result_default_wo_state(self): def test_result_w_retry_wo_state(self): begun_job_resource = _make_job_resource( - job_id=self.JOB_ID, project_id=self.PROJECT, started=True + job_id=self.JOB_ID, project_id=self.PROJECT, location="EU", started=True ) done_job_resource = _make_job_resource( - job_id=self.JOB_ID, project_id=self.PROJECT, started=True, ended=True + job_id=self.JOB_ID, + project_id=self.PROJECT, + location="EU", + started=True, + ended=True, ) conn = _make_connection( exceptions.NotFound("not normally retriable"), begun_job_resource, - # The call to done() / reload() does not get the custom retry - # policy passed to it, so we don't throw a non-retriable - # exception here. 
See: - # https://github.com/googleapis/python-bigquery/issues/24 - _make_retriable_exception(), + exceptions.NotFound("not normally retriable"), done_job_resource, ) client = _make_client(project=self.PROJECT, connection=conn) - job = self._make_one(self.JOB_ID, client) + job = self._make_one( + self._job_reference(self.JOB_ID, self.PROJECT, "EU"), client + ) custom_predicate = mock.Mock() custom_predicate.return_value = True - custom_retry = google.api_core.retry.Retry(predicate=custom_predicate) + custom_retry = google.api_core.retry.Retry( + predicate=custom_predicate, initial=0.001, maximum=0.001, deadline=0.001, + ) self.assertIs(job.result(retry=custom_retry), job) begin_call = mock.call( method="POST", path=f"/projects/{self.PROJECT}/jobs", - data={"jobReference": {"jobId": self.JOB_ID, "projectId": self.PROJECT}}, + data={ + "jobReference": { + "jobId": self.JOB_ID, + "projectId": self.PROJECT, + "location": "EU", + } + }, timeout=None, ) reload_call = mock.call( method="GET", path=f"/projects/{self.PROJECT}/jobs/{self.JOB_ID}", - query_params={}, + query_params={"location": "EU"}, timeout=None, ) conn.api_request.assert_has_calls( diff --git a/tests/unit/job/test_query.py b/tests/unit/job/test_query.py index c0b90d8ea..daaf2e557 100644 --- a/tests/unit/job/test_query.py +++ b/tests/unit/job/test_query.py @@ -23,6 +23,7 @@ import requests from six.moves import http_client +from google.cloud.bigquery.client import _LIST_ROWS_FROM_QUERY_RESULTS_FIELDS import google.cloud.bigquery.query from .helpers import _Base from .helpers import _make_client @@ -40,8 +41,10 @@ def _get_target_class(): return QueryJob - def _make_resource(self, started=False, ended=False): - resource = super(TestQueryJob, self)._make_resource(started, ended) + def _make_resource(self, started=False, ended=False, location="US"): + resource = super(TestQueryJob, self)._make_resource( + started, ended, location=location + ) config = resource["configuration"]["query"] config["query"] = self.QUERY return resource @@ -770,22 +773,30 @@ def test_result(self): query_resource = { "jobComplete": False, - "jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID}, + "jobReference": { + "projectId": self.PROJECT, + "jobId": self.JOB_ID, + "location": "EU", + }, } query_resource_done = { "jobComplete": True, - "jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID}, + "jobReference": { + "projectId": self.PROJECT, + "jobId": self.JOB_ID, + "location": "EU", + }, "schema": {"fields": [{"name": "col1", "type": "STRING"}]}, "totalRows": "2", } - job_resource = self._make_resource(started=True) - job_resource_done = self._make_resource(started=True, ended=True) + job_resource = self._make_resource(started=True, location="EU") + job_resource_done = self._make_resource(started=True, ended=True, location="EU") job_resource_done["configuration"]["query"]["destinationTable"] = { "projectId": "dest-project", "datasetId": "dest_dataset", "tableId": "dest_table", } - tabledata_resource = { + query_page_resource = { # Explicitly set totalRows to be different from the initial # response to test update during iteration. 
"totalRows": "1", @@ -793,7 +804,7 @@ def test_result(self): "rows": [{"f": [{"v": "abc"}]}], } conn = _make_connection( - query_resource, query_resource_done, job_resource_done, tabledata_resource + query_resource, query_resource_done, job_resource_done, query_page_resource ) client = _make_client(self.PROJECT, connection=conn) job = self._get_target_class().from_api_repr(job_resource, client) @@ -809,26 +820,30 @@ def test_result(self): # on the response from tabledata.list. self.assertEqual(result.total_rows, 1) + query_results_path = f"/projects/{self.PROJECT}/queries/{self.JOB_ID}" query_results_call = mock.call( method="GET", - path=f"/projects/{self.PROJECT}/queries/{self.JOB_ID}", - query_params={"maxResults": 0}, + path=query_results_path, + query_params={"maxResults": 0, "location": "EU"}, timeout=None, ) reload_call = mock.call( method="GET", path=f"/projects/{self.PROJECT}/jobs/{self.JOB_ID}", - query_params={}, + query_params={"location": "EU"}, timeout=None, ) - tabledata_call = mock.call( + query_page_call = mock.call( method="GET", - path="/projects/dest-project/datasets/dest_dataset/tables/dest_table/data", - query_params={}, + path=query_results_path, + query_params={ + "fields": _LIST_ROWS_FROM_QUERY_RESULTS_FIELDS, + "location": "EU", + }, timeout=None, ) conn.api_request.assert_has_calls( - [query_results_call, query_results_call, reload_call, tabledata_call] + [query_results_call, query_results_call, reload_call, query_page_call] ) def test_result_with_done_job_calls_get_query_results(self): @@ -838,18 +853,18 @@ def test_result_with_done_job_calls_get_query_results(self): "schema": {"fields": [{"name": "col1", "type": "STRING"}]}, "totalRows": "1", } - job_resource = self._make_resource(started=True, ended=True) + job_resource = self._make_resource(started=True, ended=True, location="EU") job_resource["configuration"]["query"]["destinationTable"] = { "projectId": "dest-project", "datasetId": "dest_dataset", "tableId": "dest_table", } - tabledata_resource = { + results_page_resource = { "totalRows": "1", "pageToken": None, "rows": [{"f": [{"v": "abc"}]}], } - conn = _make_connection(query_resource_done, tabledata_resource) + conn = _make_connection(query_resource_done, results_page_resource) client = _make_client(self.PROJECT, connection=conn) job = self._get_target_class().from_api_repr(job_resource, client) @@ -859,19 +874,23 @@ def test_result_with_done_job_calls_get_query_results(self): self.assertEqual(len(rows), 1) self.assertEqual(rows[0].col1, "abc") + query_results_path = f"/projects/{self.PROJECT}/queries/{self.JOB_ID}" query_results_call = mock.call( method="GET", - path=f"/projects/{self.PROJECT}/queries/{self.JOB_ID}", - query_params={"maxResults": 0}, + path=query_results_path, + query_params={"maxResults": 0, "location": "EU"}, timeout=None, ) - tabledata_call = mock.call( + query_results_page_call = mock.call( method="GET", - path="/projects/dest-project/datasets/dest_dataset/tables/dest_table/data", - query_params={}, + path=query_results_path, + query_params={ + "fields": _LIST_ROWS_FROM_QUERY_RESULTS_FIELDS, + "location": "EU", + }, timeout=None, ) - conn.api_request.assert_has_calls([query_results_call, tabledata_call]) + conn.api_request.assert_has_calls([query_results_call, query_results_page_call]) def test_result_with_max_results(self): from google.cloud.bigquery.table import RowIterator @@ -882,7 +901,7 @@ def test_result_with_max_results(self): "schema": {"fields": [{"name": "col1", "type": "STRING"}]}, "totalRows": "5", } - 
tabledata_resource = { + query_page_resource = { "totalRows": "5", "pageToken": None, "rows": [ @@ -891,7 +910,7 @@ def test_result_with_max_results(self): {"f": [{"v": "ghi"}]}, ], } - connection = _make_connection(query_resource, tabledata_resource) + connection = _make_connection(query_resource, query_page_resource) client = _make_client(self.PROJECT, connection=connection) resource = self._make_resource(ended=True) job = self._get_target_class().from_api_repr(resource, client) @@ -907,9 +926,9 @@ def test_result_with_max_results(self): self.assertEqual(len(rows), 3) self.assertEqual(len(connection.api_request.call_args_list), 2) - tabledata_list_request = connection.api_request.call_args_list[1] + query_page_request = connection.api_request.call_args_list[1] self.assertEqual( - tabledata_list_request[1]["query_params"]["maxResults"], max_results + query_page_request[1]["query_params"]["maxResults"], max_results ) def test_result_w_retry(self): @@ -925,8 +944,10 @@ def test_result_w_retry(self): "schema": {"fields": [{"name": "col1", "type": "STRING"}]}, "totalRows": "2", } - job_resource = self._make_resource(started=True) - job_resource_done = self._make_resource(started=True, ended=True) + job_resource = self._make_resource(started=True, location="asia-northeast1") + job_resource_done = self._make_resource( + started=True, ended=True, location="asia-northeast1" + ) job_resource_done["configuration"]["query"]["destinationTable"] = { "projectId": "dest-project", "datasetId": "dest_dataset", @@ -958,13 +979,13 @@ def test_result_w_retry(self): query_results_call = mock.call( method="GET", path=f"/projects/{self.PROJECT}/queries/{self.JOB_ID}", - query_params={"maxResults": 0}, + query_params={"maxResults": 0, "location": "asia-northeast1"}, timeout=None, ) reload_call = mock.call( method="GET", path=f"/projects/{self.PROJECT}/jobs/{self.JOB_ID}", - query_params={}, + query_params={"location": "asia-northeast1"}, timeout=None, ) @@ -1059,14 +1080,14 @@ def test_result_w_page_size(self): "schema": {"fields": [{"name": "col1", "type": "STRING"}]}, "totalRows": "4", } - job_resource = self._make_resource(started=True, ended=True) + job_resource = self._make_resource(started=True, ended=True, location="US") q_config = job_resource["configuration"]["query"] q_config["destinationTable"] = { "projectId": self.PROJECT, "datasetId": self.DS_ID, "tableId": self.TABLE_ID, } - tabledata_resource = { + query_page_resource = { "totalRows": 4, "pageToken": "some-page-token", "rows": [ @@ -1075,9 +1096,9 @@ def test_result_w_page_size(self): {"f": [{"v": "row3"}]}, ], } - tabledata_resource_page_2 = {"totalRows": 4, "rows": [{"f": [{"v": "row4"}]}]} + query_page_resource_2 = {"totalRows": 4, "rows": [{"f": [{"v": "row4"}]}]} conn = _make_connection( - query_results_resource, tabledata_resource, tabledata_resource_page_2 + query_results_resource, query_page_resource, query_page_resource_2 ) client = _make_client(self.PROJECT, connection=conn) job = self._get_target_class().from_api_repr(job_resource, client) @@ -1089,27 +1110,29 @@ def test_result_w_page_size(self): actual_rows = list(result) self.assertEqual(len(actual_rows), 4) - tabledata_path = "/projects/%s/datasets/%s/tables/%s/data" % ( - self.PROJECT, - self.DS_ID, - self.TABLE_ID, + query_results_path = f"/projects/{self.PROJECT}/queries/{self.JOB_ID}" + query_page_1_call = mock.call( + method="GET", + path=query_results_path, + query_params={ + "maxResults": 3, + "fields": _LIST_ROWS_FROM_QUERY_RESULTS_FIELDS, + "location": "US", + }, + 
timeout=None, ) - conn.api_request.assert_has_calls( - [ - mock.call( - method="GET", - path=tabledata_path, - query_params={"maxResults": 3}, - timeout=None, - ), - mock.call( - method="GET", - path=tabledata_path, - query_params={"pageToken": "some-page-token", "maxResults": 3}, - timeout=None, - ), - ] + query_page_2_call = mock.call( + method="GET", + path=query_results_path, + query_params={ + "pageToken": "some-page-token", + "maxResults": 3, + "fields": _LIST_ROWS_FROM_QUERY_RESULTS_FIELDS, + "location": "US", + }, + timeout=None, ) + conn.api_request.assert_has_calls([query_page_1_call, query_page_2_call]) def test_result_with_start_index(self): from google.cloud.bigquery.table import RowIterator diff --git a/tests/unit/test__pandas_helpers.py b/tests/unit/test__pandas_helpers.py index bdb1c56ea..ef0c40e1a 100644 --- a/tests/unit/test__pandas_helpers.py +++ b/tests/unit/test__pandas_helpers.py @@ -1202,7 +1202,7 @@ def test_dataframe_to_parquet_dict_sequence_schema(module_under_test): @pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") -def test_download_arrow_tabledata_list_unknown_field_type(module_under_test): +def test_download_arrow_row_iterator_unknown_field_type(module_under_test): fake_page = api_core.page_iterator.Page( parent=mock.Mock(), items=[{"page_data": "foo"}], @@ -1216,7 +1216,7 @@ def test_download_arrow_tabledata_list_unknown_field_type(module_under_test): schema.SchemaField("alien_field", "ALIEN_FLOAT_TYPE"), ] - results_gen = module_under_test.download_arrow_tabledata_list(pages, bq_schema) + results_gen = module_under_test.download_arrow_row_iterator(pages, bq_schema) with warnings.catch_warnings(record=True) as warned: result = next(results_gen) @@ -1238,7 +1238,7 @@ def test_download_arrow_tabledata_list_unknown_field_type(module_under_test): @pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") -def test_download_arrow_tabledata_list_known_field_type(module_under_test): +def test_download_arrow_row_iterator_known_field_type(module_under_test): fake_page = api_core.page_iterator.Page( parent=mock.Mock(), items=[{"page_data": "foo"}], @@ -1252,7 +1252,7 @@ def test_download_arrow_tabledata_list_known_field_type(module_under_test): schema.SchemaField("non_alien_field", "STRING"), ] - results_gen = module_under_test.download_arrow_tabledata_list(pages, bq_schema) + results_gen = module_under_test.download_arrow_row_iterator(pages, bq_schema) with warnings.catch_warnings(record=True) as warned: result = next(results_gen) @@ -1273,7 +1273,7 @@ def test_download_arrow_tabledata_list_known_field_type(module_under_test): @pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") -def test_download_arrow_tabledata_list_dict_sequence_schema(module_under_test): +def test_download_arrow_row_iterator_dict_sequence_schema(module_under_test): fake_page = api_core.page_iterator.Page( parent=mock.Mock(), items=[{"page_data": "foo"}], @@ -1287,7 +1287,7 @@ def test_download_arrow_tabledata_list_dict_sequence_schema(module_under_test): {"name": "non_alien_field", "type": "STRING", "mode": "NULLABLE"}, ] - results_gen = module_under_test.download_arrow_tabledata_list(pages, dict_schema) + results_gen = module_under_test.download_arrow_row_iterator(pages, dict_schema) result = next(results_gen) assert len(result.columns) == 2 @@ -1301,7 +1301,7 @@ def test_download_arrow_tabledata_list_dict_sequence_schema(module_under_test): @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") 
@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`") -def test_download_dataframe_tabledata_list_dict_sequence_schema(module_under_test): +def test_download_dataframe_row_iterator_dict_sequence_schema(module_under_test): fake_page = api_core.page_iterator.Page( parent=mock.Mock(), items=[{"page_data": "foo"}], @@ -1315,7 +1315,7 @@ def test_download_dataframe_tabledata_list_dict_sequence_schema(module_under_tes {"name": "non_alien_field", "type": "STRING", "mode": "NULLABLE"}, ] - results_gen = module_under_test.download_dataframe_tabledata_list( + results_gen = module_under_test.download_dataframe_row_iterator( pages, dict_schema, dtypes={} ) result = next(results_gen) @@ -1335,5 +1335,5 @@ def test_download_dataframe_tabledata_list_dict_sequence_schema(module_under_tes def test_table_data_listpage_to_dataframe_skips_stop_iteration(module_under_test): - dataframe = module_under_test._tabledata_list_page_to_dataframe([], [], {}) + dataframe = module_under_test._row_iterator_page_to_dataframe([], [], {}) assert isinstance(dataframe, pandas.DataFrame) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index e507834f6..ca2f7ea66 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -6786,12 +6786,17 @@ def _bigquery_timestamp_float_repr(ts_float): age = SchemaField("age", "INTEGER", mode="NULLABLE") joined = SchemaField("joined", "TIMESTAMP", mode="NULLABLE") table = Table(self.TABLE_REF, schema=[full_name, age, joined]) + table._properties["numRows"] = 7 iterator = client.list_rows(table, timeout=7.5) + + # Check that initial total_rows is populated from the table. + self.assertEqual(iterator.total_rows, 7) page = six.next(iterator.pages) rows = list(page) - total_rows = iterator.total_rows - page_token = iterator.next_page_token + + # Check that total_rows is updated based on API response. 
+ self.assertEqual(iterator.total_rows, ROWS) f2i = {"full_name": 0, "age": 1, "joined": 2} self.assertEqual(len(rows), 4) @@ -6799,8 +6804,7 @@ def _bigquery_timestamp_float_repr(ts_float): self.assertEqual(rows[1], Row(("Bharney Rhubble", 33, WHEN_1), f2i)) self.assertEqual(rows[2], Row(("Wylma Phlyntstone", 29, WHEN_2), f2i)) self.assertEqual(rows[3], Row(("Bhettye Rhubble", None, None), f2i)) - self.assertEqual(total_rows, ROWS) - self.assertEqual(page_token, TOKEN) + self.assertEqual(iterator.next_page_token, TOKEN) conn.api_request.assert_called_once_with( method="GET", path="/%s" % PATH, query_params={}, timeout=7.5 diff --git a/tests/unit/test_magics.py b/tests/unit/test_magics.py index b2877845a..a7cf92919 100644 --- a/tests/unit/test_magics.py +++ b/tests/unit/test_magics.py @@ -170,7 +170,7 @@ def test_context_with_default_connection(): default_conn = make_connection(QUERY_RESOURCE, QUERY_RESULTS_RESOURCE) conn_patch = mock.patch("google.cloud.bigquery.client.Connection", autospec=True) list_rows_patch = mock.patch( - "google.cloud.bigquery.client.Client.list_rows", + "google.cloud.bigquery.client.Client._list_rows_from_query_results", return_value=google.cloud.bigquery.table._EmptyRowIterator(), ) @@ -235,7 +235,7 @@ def test_context_with_custom_connection(): default_conn = make_connection() conn_patch = mock.patch("google.cloud.bigquery.client.Connection", autospec=True) list_rows_patch = mock.patch( - "google.cloud.bigquery.client.Client.list_rows", + "google.cloud.bigquery.client.Client._list_rows_from_query_results", return_value=google.cloud.bigquery.table._EmptyRowIterator(), ) @@ -1078,7 +1078,7 @@ def test_bigquery_magic_w_maximum_bytes_billed_overrides_context(param_value, ex ) conn = magics.context._connection = make_connection(resource, query_results, data) list_rows_patch = mock.patch( - "google.cloud.bigquery.client.Client.list_rows", + "google.cloud.bigquery.client.Client._list_rows_from_query_results", return_value=google.cloud.bigquery.table._EmptyRowIterator(), ) with list_rows_patch, default_patch: @@ -1117,7 +1117,7 @@ def test_bigquery_magic_w_maximum_bytes_billed_w_context_inplace(): ) conn = magics.context._connection = make_connection(resource, query_results, data) list_rows_patch = mock.patch( - "google.cloud.bigquery.client.Client.list_rows", + "google.cloud.bigquery.client.Client._list_rows_from_query_results", return_value=google.cloud.bigquery.table._EmptyRowIterator(), ) with list_rows_patch, default_patch: @@ -1156,7 +1156,7 @@ def test_bigquery_magic_w_maximum_bytes_billed_w_context_setter(): ) conn = magics.context._connection = make_connection(resource, query_results, data) list_rows_patch = mock.patch( - "google.cloud.bigquery.client.Client.list_rows", + "google.cloud.bigquery.client.Client._list_rows_from_query_results", return_value=google.cloud.bigquery.table._EmptyRowIterator(), ) with list_rows_patch, default_patch: diff --git a/tests/unit/test_table.py b/tests/unit/test_table.py index e21453b9f..e232f32e6 100644 --- a/tests/unit/test_table.py +++ b/tests/unit/test_table.py @@ -1572,10 +1572,7 @@ def test_constructor_with_table(self): from google.cloud.bigquery.table import Table table = Table("proj.dset.tbl") - table._properties["numRows"] = 100 - - iterator = self._make_one(table=table) - + iterator = self._make_one(table=table, total_rows=100) self.assertIs(iterator._table, table) self.assertEqual(iterator.total_rows, 100) @@ -1883,7 +1880,7 @@ def test_to_arrow_max_results_w_create_bqstorage_warning(self): for warning in warned if 
warning.category is UserWarning and "cannot use bqstorage_client" in str(warning).lower() - and "tabledata.list" in str(warning) + and "REST" in str(warning) ] self.assertEqual(len(matches), 1, msg="User warning was not emitted.") mock_client._create_bqstorage_client.assert_not_called() @@ -2667,7 +2664,7 @@ def test_to_dataframe_max_results_w_bqstorage_warning(self): for warning in warned if warning.category is UserWarning and "cannot use bqstorage_client" in str(warning).lower() - and "tabledata.list" in str(warning) + and "REST" in str(warning) ] self.assertEqual(len(matches), 1, msg="User warning was not emitted.") @@ -2703,7 +2700,7 @@ def test_to_dataframe_max_results_w_create_bqstorage_warning(self): for warning in warned if warning.category is UserWarning and "cannot use bqstorage_client" in str(warning).lower() - and "tabledata.list" in str(warning) + and "REST" in str(warning) ] self.assertEqual(len(matches), 1, msg="User warning was not emitted.") mock_client._create_bqstorage_client.assert_not_called()
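
A minimal usage sketch of the new row-fetching path, assuming google-cloud-bigquery with this patch applied; the project name, query, and location below are hypothetical. After a query job completes, QueryJob.result() now pages rows through the private Client._list_rows_from_query_results() helper added above, which issues GET requests against /projects/{project}/queries/{job_id} with fields=jobReference,totalRows,pageToken,rows and the job's location, rather than calling tabledata.list on the destination table.

    from google.cloud import bigquery

    client = bigquery.Client(project="my-project")  # hypothetical project
    job = client.query("SELECT 1 AS col1", location="EU")  # hypothetical query and location

    # Per the change to job/query.py above, result() internally calls:
    #   client._list_rows_from_query_results(
    #       job_id, location, project, schema,
    #       total_rows=..., destination=..., max_results=..., start_index=..., page_size=...,
    #   )
    # Each page request targets /projects/{project}/queries/{job_id} with
    # "fields" set to _LIST_ROWS_FROM_QUERY_RESULTS_FIELDS and "location" set
    # to the job's location, so results resolve correctly for non-US jobs.
    for row in job.result(page_size=1000):
        print(row.col1)

The location handling is the user-visible part of this change: because the page requests carry the job's location (see the query_params assertions added to the tests above), jobs running outside the US multi-region no longer depend on the destination table being resolvable via tabledata.list.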