diff --git a/google/cloud/bigquery/dbapi/cursor.py b/google/cloud/bigquery/dbapi/cursor.py
index 40de336bd..7a10637f0 100644
--- a/google/cloud/bigquery/dbapi/cursor.py
+++ b/google/cloud/bigquery/dbapi/cursor.py
@@ -15,6 +15,7 @@
 """Cursor for the Google BigQuery DB-API."""
 
 import collections
+import copy
 import warnings
 
 try:
@@ -93,18 +94,16 @@ def _set_description(self, schema):
             return
 
         self.description = tuple(
-            [
-                Column(
-                    name=field.name,
-                    type_code=field.field_type,
-                    display_size=None,
-                    internal_size=None,
-                    precision=None,
-                    scale=None,
-                    null_ok=field.is_nullable,
-                )
-                for field in schema
-            ]
+            Column(
+                name=field.name,
+                type_code=field.field_type,
+                display_size=None,
+                internal_size=None,
+                precision=None,
+                scale=None,
+                null_ok=field.is_nullable,
+            )
+            for field in schema
         )
 
     def _set_rowcount(self, query_results):
@@ -173,12 +172,24 @@ def execute(self, operation, parameters=None, job_id=None, job_config=None):
         formatted_operation = _format_operation(operation, parameters=parameters)
         query_parameters = _helpers.to_query_parameters(parameters)
 
-        config = job_config or job.QueryJobConfig(use_legacy_sql=False)
+        if client._default_query_job_config:
+            if job_config:
+                config = job_config._fill_from_default(client._default_query_job_config)
+            else:
+                config = copy.deepcopy(client._default_query_job_config)
+        else:
+            config = job_config or job.QueryJobConfig(use_legacy_sql=False)
+
         config.query_parameters = query_parameters
         self._query_job = client.query(
             formatted_operation, job_config=config, job_id=job_id
         )
 
+        if self._query_job.dry_run:
+            self._set_description(schema=None)
+            self.rowcount = 0
+            return
+
         # Wait for the query to finish.
         try:
             self._query_job.result()
@@ -211,6 +222,10 @@ def _try_fetch(self, size=None):
                 "No query results: execute() must be called before fetch."
             )
 
+        if self._query_job.dry_run:
+            self._query_data = iter([])
+            return
+
         is_dml = (
             self._query_job.statement_type
             and self._query_job.statement_type.upper() != "SELECT"
@@ -307,6 +322,9 @@ def _bqstorage_fetch(self, bqstorage_client):
     def fetchone(self):
         """Fetch a single row from the results of the last ``execute*()`` call.
 
+        .. note::
+            If a dry run query was executed, no rows are returned.
+
         Returns:
             Tuple:
                 A tuple representing a row or ``None`` if no more data is
@@ -324,6 +342,9 @@ def fetchmany(self, size=None):
         """Fetch multiple results from the last ``execute*()`` call.
 
+        .. note::
+            If a dry run query was executed, no rows are returned.
+
         .. note::
             The size parameter is not used for the request/response size.
             Set the ``arraysize`` attribute before calling ``execute()`` to
@@ -360,6 +381,9 @@ def fetchall(self):
         """Fetch all remaining results from the last ``execute*()`` call.
 
+        .. note::
+            If a dry run query was executed, no rows are returned.
+
         Returns:
             List[Tuple]: A list of all the rows in the results.
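The branch added to `execute()` above decides which `QueryJobConfig` actually reaches `client.query()`: an explicit `job_config` is merged per-property over the client-wide default, a default alone is deep-copied so the cursor never mutates it, and with neither present a fresh standard-SQL config is built. For review convenience, here is that merge order restated as a standalone sketch; `resolve_query_config` is a hypothetical helper name, and `_fill_from_default` is the library-private method the patch itself calls:

```python
import copy

from google.cloud import bigquery


def resolve_query_config(job_config, default_config):
    """Mirror the merge order used by Cursor.execute() in this patch.

    Properties set on ``job_config`` win; unset ones are filled in
    from the client's ``default_query_job_config``.
    """
    if default_config:
        if job_config:
            # Private helper, but exactly what the patch relies on.
            return job_config._fill_from_default(default_config)
        # Deep copy so that later mutation (query_parameters) cannot
        # leak back into the client-wide default.
        return copy.deepcopy(default_config)
    return job_config or bigquery.QueryJobConfig(use_legacy_sql=False)
```

`execute()` then stamps the bound query parameters onto whichever config comes back, which is why the unit tests below always expect `query_parameters=[]` in the effective config.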
diff --git a/tests/system.py b/tests/system.py
index 965c34331..14d3f49a1 100644
--- a/tests/system.py
+++ b/tests/system.py
@@ -1782,6 +1782,22 @@ def test_dbapi_fetch_w_bqstorage_client_v1beta1_large_result_set(self):
         ]
         self.assertEqual(fetched_data, expected_data)
 
+    def test_dbapi_dry_run_query(self):
+        from google.cloud.bigquery.job import QueryJobConfig
+
+        query = """
+            SELECT country_name
+            FROM `bigquery-public-data.utility_us.country_code_iso`
+            WHERE country_name LIKE 'U%'
+        """
+
+        Config.CURSOR.execute(query, job_config=QueryJobConfig(dry_run=True))
+        self.assertEqual(Config.CURSOR.rowcount, 0, "expected no rows")
+
+        rows = Config.CURSOR.fetchall()
+
+        self.assertEqual(list(rows), [])
+
     @unittest.skipIf(
         bigquery_storage_v1 is None, "Requires `google-cloud-bigquery-storage`"
     )
diff --git a/tests/unit/test_dbapi_cursor.py b/tests/unit/test_dbapi_cursor.py
index 129ce28ad..bd1d9dc0a 100644
--- a/tests/unit/test_dbapi_cursor.py
+++ b/tests/unit/test_dbapi_cursor.py
@@ -46,7 +46,15 @@ def _get_target_class():
     def _make_one(self, *args, **kw):
         return self._get_target_class()(*args, **kw)
 
-    def _mock_client(self, rows=None, schema=None, num_dml_affected_rows=None):
+    def _mock_client(
+        self,
+        rows=None,
+        schema=None,
+        num_dml_affected_rows=None,
+        default_query_job_config=None,
+        dry_run_job=False,
+        total_bytes_processed=0,
+    ):
         from google.cloud.bigquery import client
 
         if rows is None:
@@ -59,8 +67,11 @@ def _mock_client(self, rows=None, schema=None, num_dml_affected_rows=None):
             total_rows=total_rows,
             schema=schema,
             num_dml_affected_rows=num_dml_affected_rows,
+            dry_run=dry_run_job,
+            total_bytes_processed=total_bytes_processed,
         )
         mock_client.list_rows.return_value = rows
+        mock_client._default_query_job_config = default_query_job_config
 
         # Assure that the REST client gets used, not the BQ Storage client.
         mock_client._create_bqstorage_client.return_value = None
 
@@ -95,27 +106,41 @@ def _mock_bqstorage_client(self, rows=None, stream_count=0, v1beta1=False):
         )
         mock_client.create_read_session.return_value = mock_read_session
+
         mock_rows_stream = mock.MagicMock()
         mock_rows_stream.rows.return_value = iter(rows)
         mock_client.read_rows.return_value = mock_rows_stream
 
         return mock_client
 
-    def _mock_job(self, total_rows=0, schema=None, num_dml_affected_rows=None):
+    def _mock_job(
+        self,
+        total_rows=0,
+        schema=None,
+        num_dml_affected_rows=None,
+        dry_run=False,
+        total_bytes_processed=0,
+    ):
         from google.cloud.bigquery import job
 
         mock_job = mock.create_autospec(job.QueryJob)
         mock_job.error_result = None
         mock_job.state = "DONE"
-        mock_job.result.return_value = mock_job
-        mock_job._query_results = self._mock_results(
-            total_rows=total_rows,
-            schema=schema,
-            num_dml_affected_rows=num_dml_affected_rows,
-        )
-        mock_job.destination.to_bqstorage.return_value = (
-            "projects/P/datasets/DS/tables/T"
-        )
+        mock_job.dry_run = dry_run
+
+        if dry_run:
+            mock_job.result.side_effect = exceptions.NotFound
+            mock_job.total_bytes_processed = total_bytes_processed
+        else:
+            mock_job.result.return_value = mock_job
+            mock_job._query_results = self._mock_results(
+                total_rows=total_rows,
+                schema=schema,
+                num_dml_affected_rows=num_dml_affected_rows,
+            )
+            mock_job.destination.to_bqstorage.return_value = (
+                "projects/P/datasets/DS/tables/T"
+            )
 
         if num_dml_affected_rows is None:
             mock_job.statement_type = None  # API sends back None for SELECT
@@ -445,7 +470,27 @@ def test_execute_custom_job_id(self):
         self.assertEqual(args[0], "SELECT 1;")
         self.assertEqual(kwargs["job_id"], "foo")
 
-    def test_execute_custom_job_config(self):
+    def test_execute_w_default_config(self):
+        from google.cloud.bigquery.dbapi import connect
+        from google.cloud.bigquery import job
+
+        default_config = job.QueryJobConfig(use_legacy_sql=False, flatten_results=True)
+        client = self._mock_client(
+            rows=[], num_dml_affected_rows=0, default_query_job_config=default_config
+        )
+        connection = connect(client)
+        cursor = connection.cursor()
+
+        cursor.execute("SELECT 1;", job_id="foo")
+
+        _, kwargs = client.query.call_args
+        used_config = kwargs["job_config"]
+        expected_config = job.QueryJobConfig(
+            use_legacy_sql=False, flatten_results=True, query_parameters=[]
+        )
+        self.assertEqual(used_config._properties, expected_config._properties)
+
+    def test_execute_custom_job_config_wo_default_config(self):
         from google.cloud.bigquery.dbapi import connect
         from google.cloud.bigquery import job
 
@@ -459,6 +504,29 @@ def test_execute_custom_job_config(self):
         self.assertEqual(kwargs["job_id"], "foo")
         self.assertEqual(kwargs["job_config"], config)
 
+    def test_execute_custom_job_config_w_default_config(self):
+        from google.cloud.bigquery.dbapi import connect
+        from google.cloud.bigquery import job
+
+        default_config = job.QueryJobConfig(use_legacy_sql=False, flatten_results=True)
+        client = self._mock_client(
+            rows=[], num_dml_affected_rows=0, default_query_job_config=default_config
+        )
+        connection = connect(client)
+        cursor = connection.cursor()
+        config = job.QueryJobConfig(use_legacy_sql=True)
+
+        cursor.execute("SELECT 1;", job_id="foo", job_config=config)
+
+        _, kwargs = client.query.call_args
+        used_config = kwargs["job_config"]
+        expected_config = job.QueryJobConfig(
+            use_legacy_sql=True,  # the config passed to execute() prevails
+            flatten_results=True,  # from the default
+            query_parameters=[],
+        )
+        self.assertEqual(used_config._properties, expected_config._properties)
+
     def test_execute_w_dml(self):
         from google.cloud.bigquery.dbapi import connect
 
@@ -514,6 +582,35 @@ def test_execute_w_query(self):
         row = cursor.fetchone()
         self.assertIsNone(row)
 
+    def test_execute_w_query_dry_run(self):
+        from google.cloud.bigquery.job import QueryJobConfig
+        from google.cloud.bigquery.schema import SchemaField
+        from google.cloud.bigquery import dbapi
+
+        connection = dbapi.connect(
+            self._mock_client(
+                rows=[("hello", "world", 1), ("howdy", "y'all", 2)],
+                schema=[
+                    SchemaField("a", "STRING", mode="NULLABLE"),
+                    SchemaField("b", "STRING", mode="REQUIRED"),
+                    SchemaField("c", "INTEGER", mode="NULLABLE"),
+                ],
+                dry_run_job=True,
+                total_bytes_processed=12345,
+            )
+        )
+        cursor = connection.cursor()
+
+        cursor.execute(
+            "SELECT a, b, c FROM hello_world WHERE d > 3;",
+            job_config=QueryJobConfig(dry_run=True),
+        )
+
+        self.assertEqual(cursor.rowcount, 0)
+        self.assertIsNone(cursor.description)
+        rows = cursor.fetchall()
+        self.assertEqual(list(rows), [])
+
     def test_execute_raises_if_result_raises(self):
         import google.cloud.exceptions
 
@@ -523,8 +620,10 @@ def test_execute_raises_if_result_raises(self):
         from google.cloud.bigquery.dbapi import exceptions
 
         job = mock.create_autospec(job.QueryJob)
+        job.dry_run = None
         job.result.side_effect = google.cloud.exceptions.GoogleCloudError("")
         client = mock.create_autospec(client.Client)
+        client._default_query_job_config = None
         client.query.return_value = job
         connection = connect(client)
         cursor = connection.cursor()
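End to end, the new code path makes a dry run observable through the plain DB-API surface: `execute()` returns before waiting on the job, `rowcount` is forced to 0, `description` stays `None`, and every `fetch*()` method drains an empty iterator. A minimal usage sketch, assuming application default credentials and access to a public dataset; note that the byte estimate is only reachable through the private `_query_job` attribute, since the patch adds no public accessor for it:

```python
from google.cloud import bigquery
from google.cloud.bigquery import dbapi
from google.cloud.bigquery.job import QueryJobConfig

client = bigquery.Client()  # assumes default credentials and project
connection = dbapi.connect(client)
cursor = connection.cursor()

cursor.execute(
    "SELECT name FROM `bigquery-public-data.usa_names.usa_1910_2013`",
    job_config=QueryJobConfig(dry_run=True),
)

# A dry run produces no rows: execute() returned early and the
# fetch methods yield nothing.
assert cursor.rowcount == 0
assert cursor.fetchall() == []

# Estimated bytes scanned, via the private job handle.
print(cursor._query_job.total_bytes_processed)
```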