diff --git a/google/cloud/bigquery/dbapi/cursor.py b/google/cloud/bigquery/dbapi/cursor.py
index eb73b3d56..408fdfe55 100644
--- a/google/cloud/bigquery/dbapi/cursor.py
+++ b/google/cloud/bigquery/dbapi/cursor.py
@@ -15,6 +15,7 @@
 """Cursor for the Google BigQuery DB-API."""

 import collections
+import copy

 try:
     from collections import abc as collections_abc
@@ -26,6 +27,8 @@ import six

 from google.cloud.bigquery import job
+from google.cloud.bigquery import schema
+from google.cloud.bigquery import table
 from google.cloud.bigquery.dbapi import _helpers
 from google.cloud.bigquery.dbapi import exceptions
 import google.cloud.exceptions

@@ -89,18 +92,16 @@ def _set_description(self, schema):
             return

         self.description = tuple(
-            [
-                Column(
-                    name=field.name,
-                    type_code=field.field_type,
-                    display_size=None,
-                    internal_size=None,
-                    precision=None,
-                    scale=None,
-                    null_ok=field.is_nullable,
-                )
-                for field in schema
-            ]
+            Column(
+                name=field.name,
+                type_code=field.field_type,
+                display_size=None,
+                internal_size=None,
+                precision=None,
+                scale=None,
+                null_ok=field.is_nullable,
+            )
+            for field in schema
         )

     def _set_rowcount(self, query_results):
@@ -169,12 +170,27 @@ def execute(self, operation, parameters=None, job_id=None, job_config=None):
         formatted_operation = _format_operation(operation, parameters=parameters)
         query_parameters = _helpers.to_query_parameters(parameters)

-        config = job_config or job.QueryJobConfig(use_legacy_sql=False)
+        if client._default_query_job_config:
+            if job_config:
+                config = job_config._fill_from_default(client._default_query_job_config)
+            else:
+                config = copy.deepcopy(client._default_query_job_config)
+        else:
+            config = job_config or job.QueryJobConfig(use_legacy_sql=False)
+
         config.query_parameters = query_parameters
         self._query_job = client.query(
             formatted_operation, job_config=config, job_id=job_id
         )

+        if self._query_job.dry_run:
+            schema_field = schema.SchemaField(
+                name="estimated_bytes", field_type="INTEGER", mode="REQUIRED",
+            )
+            self._set_description(schema=[schema_field])
+            self.rowcount = 1
+            return
+
         # Wait for the query to finish.
         try:
             self._query_job.result()
@@ -207,6 +223,12 @@ def _try_fetch(self, size=None):
                 "No query results: execute() must be called before fetch."
             )

+        if self._query_job.dry_run:
+            estimated_bytes = self._query_job.total_bytes_processed
+            row = table.Row((estimated_bytes,), {"estimated_bytes": 0})
+            self._query_data = iter([row])
+            return
+
         is_dml = (
             self._query_job.statement_type
             and self._query_job.statement_type.upper() != "SELECT"
@@ -290,6 +312,11 @@ def _bqstorage_fetch(self, bqstorage_client):
     def fetchone(self):
         """Fetch a single row from the results of the last ``execute*()`` call.

+        .. note::
+            If a dry run query was executed, a row with a single value is
+            returned representing the estimated number of bytes that would be
+            processed by the query.
+
         Returns:
             Tuple:
                 A tuple representing a row or ``None`` if no more data is
@@ -307,6 +334,11 @@ def fetchmany(self, size=None):
         """Fetch multiple results from the last ``execute*()`` call.

+        .. note::
+            If a dry run query was executed, a row with a single value is
+            returned representing the estimated number of bytes that would be
+            processed by the query.
+
         .. note::
             The size parameter is not used for the request/response size.
             Set the ``arraysize`` attribute before calling ``execute()`` to
@@ -343,6 +375,11 @@ def fetchall(self):
         """Fetch all remaining results from the last ``execute*()`` call.

+        .. note::
+            If a dry run query was executed, a row with a single value is
+            returned representing the estimated number of bytes that would be
+            processed by the query.
+
         Returns:
             List[Tuple]: A list of all the rows in the results.

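Taken together, the `execute()` and `_try_fetch()` changes above give dry run queries a one-row result surface through the standard DB-API calls. A minimal usage sketch (assuming an already authenticated client; the queried table is a hypothetical placeholder):

```python
from google.cloud import bigquery
from google.cloud.bigquery import dbapi
from google.cloud.bigquery.job import QueryJobConfig

client = bigquery.Client()
cursor = dbapi.connect(client).cursor()

# With dry_run=True the job is only validated and estimated, never run, so
# execute() returns immediately instead of waiting for a result.
cursor.execute(
    "SELECT name FROM `my-project.my_dataset.my_table`",  # placeholder table
    job_config=QueryJobConfig(dry_run=True),
)

# The cursor then exposes exactly one row with one "estimated_bytes" column.
print(cursor.description[0].name)  # -> estimated_bytes
(estimated_bytes,) = cursor.fetchone()
print("The query would process {} bytes.".format(estimated_bytes))
```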
diff --git a/tests/system.py b/tests/system.py
index 66d7ee259..8b0b9c401 100644
--- a/tests/system.py
+++ b/tests/system.py
@@ -1782,6 +1782,24 @@ def test_dbapi_fetch_w_bqstorage_client_large_result_set(self):
         ]
         self.assertEqual(fetched_data, expected_data)

+    def test_dbapi_dry_run_query(self):
+        from google.cloud.bigquery.job import QueryJobConfig
+
+        query = """
+            SELECT country_name
+            FROM `bigquery-public-data.utility_us.country_code_iso`
+            WHERE country_name LIKE 'U%'
+        """
+
+        Config.CURSOR.execute(query, job_config=QueryJobConfig(dry_run=True))
+        self.assertEqual(Config.CURSOR.rowcount, 1, "expected a single row")
+
+        rows = Config.CURSOR.fetchall()
+
+        row_tuples = [r.values() for r in rows]
+        expected = [(3473,)]
+        self.assertEqual(row_tuples, expected)
+
     def _load_table_for_dml(self, rows, dataset_id, table_id):
         from google.cloud._testing import _NamedTemporaryFile
         from google.cloud.bigquery.job import CreateDisposition
diff --git a/tests/unit/test_dbapi_cursor.py b/tests/unit/test_dbapi_cursor.py
index e53cc158a..8259af24c 100644
--- a/tests/unit/test_dbapi_cursor.py
+++ b/tests/unit/test_dbapi_cursor.py
@@ -36,7 +36,15 @@ def _get_target_class():
     def _make_one(self, *args, **kw):
         return self._get_target_class()(*args, **kw)

-    def _mock_client(self, rows=None, schema=None, num_dml_affected_rows=None):
+    def _mock_client(
+        self,
+        rows=None,
+        schema=None,
+        num_dml_affected_rows=None,
+        default_query_job_config=None,
+        dry_run_job=False,
+        total_bytes_processed=0,
+    ):
         from google.cloud.bigquery import client

         if rows is None:
@@ -49,8 +57,12 @@ def _mock_client(self, rows=None, schema=None, num_dml_affected_rows=None):
             total_rows=total_rows,
             schema=schema,
             num_dml_affected_rows=num_dml_affected_rows,
+            dry_run=dry_run_job,
+            total_bytes_processed=total_bytes_processed,
         )
         mock_client.list_rows.return_value = rows
+        mock_client._default_query_job_config = default_query_job_config
+
         return mock_client

     def _mock_bqstorage_client(self, rows=None, stream_count=0):
@@ -76,18 +88,31 @@ def _mock_bqstorage_client(self, rows=None, stream_count=0):

         return mock_client

-    def _mock_job(self, total_rows=0, schema=None, num_dml_affected_rows=None):
+    def _mock_job(
+        self,
+        total_rows=0,
+        schema=None,
+        num_dml_affected_rows=None,
+        dry_run=False,
+        total_bytes_processed=0,
+    ):
         from google.cloud.bigquery import job

         mock_job = mock.create_autospec(job.QueryJob)
         mock_job.error_result = None
         mock_job.state = "DONE"
-        mock_job.result.return_value = mock_job
-        mock_job._query_results = self._mock_results(
-            total_rows=total_rows,
-            schema=schema,
-            num_dml_affected_rows=num_dml_affected_rows,
-        )
+        mock_job.dry_run = dry_run
+
+        if dry_run:
+            mock_job.result.side_effect = exceptions.NotFound
+            mock_job.total_bytes_processed = total_bytes_processed
+        else:
+            mock_job.result.return_value = mock_job
+            mock_job._query_results = self._mock_results(
+                total_rows=total_rows,
+                schema=schema,
+                num_dml_affected_rows=num_dml_affected_rows,
+            )

         if num_dml_affected_rows is None:
             mock_job.statement_type = None  # API sends back None for SELECT
@@ -373,7 +398,27 @@ def test_execute_custom_job_id(self):
         self.assertEqual(args[0], "SELECT 1;")
         self.assertEqual(kwargs["job_id"], "foo")

-    def test_execute_custom_job_config(self):
+    def test_execute_w_default_config(self):
+        from google.cloud.bigquery.dbapi import connect
+        from google.cloud.bigquery import job
+
+        default_config = job.QueryJobConfig(use_legacy_sql=False, flatten_results=True)
+        client = self._mock_client(
+            rows=[], num_dml_affected_rows=0, default_query_job_config=default_config
+        )
+        connection = connect(client)
+        cursor = connection.cursor()
+
+        cursor.execute("SELECT 1;", job_id="foo")
+
+        _, kwargs = client.query.call_args
+        used_config = kwargs["job_config"]
+        expected_config = job.QueryJobConfig(
+            use_legacy_sql=False, flatten_results=True, query_parameters=[]
+        )
+        self.assertEqual(used_config._properties, expected_config._properties)
+
+    def test_execute_custom_job_config_wo_default_config(self):
         from google.cloud.bigquery.dbapi import connect
         from google.cloud.bigquery import job

@@ -387,6 +432,29 @@
         self.assertEqual(kwargs["job_id"], "foo")
         self.assertEqual(kwargs["job_config"], config)

+    def test_execute_custom_job_config_w_default_config(self):
+        from google.cloud.bigquery.dbapi import connect
+        from google.cloud.bigquery import job
+
+        default_config = job.QueryJobConfig(use_legacy_sql=False, flatten_results=True)
+        client = self._mock_client(
+            rows=[], num_dml_affected_rows=0, default_query_job_config=default_config
+        )
+        connection = connect(client)
+        cursor = connection.cursor()
+        config = job.QueryJobConfig(use_legacy_sql=True)
+
+        cursor.execute("SELECT 1;", job_id="foo", job_config=config)
+
+        _, kwargs = client.query.call_args
+        used_config = kwargs["job_config"]
+        expected_config = job.QueryJobConfig(
+            use_legacy_sql=True,  # the config passed to execute() prevails
+            flatten_results=True,  # from the default
+            query_parameters=[],
+        )
+        self.assertEqual(used_config._properties, expected_config._properties)
+
     def test_execute_w_dml(self):
         from google.cloud.bigquery.dbapi import connect

@@ -442,6 +510,52 @@ def test_execute_w_query(self):
         row = cursor.fetchone()
         self.assertIsNone(row)

+    def test_execute_w_query_dry_run(self):
+        from google.cloud.bigquery.job import QueryJobConfig
+        from google.cloud.bigquery.schema import SchemaField
+        from google.cloud.bigquery import dbapi
+
+        connection = dbapi.connect(
+            self._mock_client(
+                rows=[("hello", "world", 1), ("howdy", "y'all", 2)],
+                schema=[
+                    SchemaField("a", "STRING", mode="NULLABLE"),
+                    SchemaField("b", "STRING", mode="REQUIRED"),
+                    SchemaField("c", "INTEGER", mode="NULLABLE"),
+                ],
+                dry_run_job=True,
+                total_bytes_processed=12345,
+            )
+        )
+        cursor = connection.cursor()
+
+        cursor.execute(
+            "SELECT a, b, c FROM hello_world WHERE d > 3;",
+            job_config=QueryJobConfig(dry_run=True),
+        )
+
+        expected_description = (
+            dbapi.cursor.Column(
+                name="estimated_bytes",
+                type_code="INTEGER",
+                display_size=None,
+                internal_size=None,
+                precision=None,
+                scale=None,
+                null_ok=False,
+            ),
+        )
+        self.assertEqual(cursor.description, expected_description)
+        self.assertEqual(cursor.rowcount, 1)
+
+        rows = cursor.fetchall()
+
+        # We expect a single row with one column - the estimated number of bytes
+        # that will be processed by the query.
+        self.assertEqual(len(rows), 1)
+        self.assertEqual(len(rows[0]), 1)
+        self.assertEqual(rows[0][0], 12345)
+
     def test_execute_raises_if_result_raises(self):
         import google.cloud.exceptions

@@ -451,8 +565,10 @@ from google.cloud.bigquery.dbapi import connect
         from google.cloud.bigquery.dbapi import exceptions

         job = mock.create_autospec(job.QueryJob)
+        job.dry_run = None
         job.result.side_effect = google.cloud.exceptions.GoogleCloudError("")
         client = mock.create_autospec(client.Client)
+        client._default_query_job_config = None
         client.query.return_value = job
         connection = connect(client)
         cursor = connection.cursor()
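For reference, a minimal sketch of the config-precedence behavior that `test_execute_w_default_config` and `test_execute_custom_job_config_w_default_config` pin down. It assumes a client created with the public `default_query_job_config` option; project and credentials setup is omitted:

```python
from google.cloud import bigquery
from google.cloud.bigquery import dbapi
from google.cloud.bigquery.job import QueryJobConfig

default_config = QueryJobConfig(use_legacy_sql=False, flatten_results=True)
client = bigquery.Client(default_query_job_config=default_config)
cursor = dbapi.connect(client).cursor()

# No config passed to execute(): the client-level default is deep-copied
# and used as-is.
cursor.execute("SELECT 1;")

# Config passed to execute(): its explicitly set options win, and anything
# left unset falls back to the default, so this query runs with
# use_legacy_sql=True (explicit) and flatten_results=True (inherited).
cursor.execute("SELECT 1;", job_config=QueryJobConfig(use_legacy_sql=True))
```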