fix: dry run queries with DB API cursor
plamut committed Jun 10, 2020
1 parent 3869e34 commit dac879c
Showing 3 changed files with 167 additions and 10 deletions.
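Before this commit, executing a query with dry_run=True through the DB-API cursor failed: a dry run job never produces a result set, so the cursor waited on results that would never arrive. With this change the cursor short-circuits and exposes the dry run estimate as a single synthetic row. A minimal usage sketch of the new behavior (the project ID and query are placeholders, not part of this commit):

    from google.cloud import bigquery
    from google.cloud.bigquery import dbapi
    from google.cloud.bigquery.job import QueryJobConfig

    # Hypothetical project; substitute your own.
    client = bigquery.Client(project="my-project")
    cursor = dbapi.connect(client).cursor()

    # A dry run validates the query and estimates cost without running it.
    cursor.execute(
        "SELECT name FROM `bigquery-public-data.usa_names.usa_1910_2013`",
        job_config=QueryJobConfig(dry_run=True),
    )

    print(cursor.rowcount)       # 1 (a single synthetic row)
    print(cursor.fetchone()[0])  # estimated bytes processed by the query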
37 changes: 36 additions & 1 deletion google/cloud/bigquery/dbapi/cursor.py
@@ -15,6 +15,7 @@
"""Cursor for the Google BigQuery DB-API."""

import collections
import copy

try:
from collections import abc as collections_abc
@@ -26,6 +27,7 @@
import six

from google.cloud.bigquery import job
from google.cloud.bigquery import table
from google.cloud.bigquery.dbapi import _helpers
from google.cloud.bigquery.dbapi import exceptions
import google.cloud.exceptions
@@ -169,12 +171,24 @@ def execute(self, operation, parameters=None, job_id=None, job_config=None):
formatted_operation = _format_operation(operation, parameters=parameters)
query_parameters = _helpers.to_query_parameters(parameters)

config = job_config or job.QueryJobConfig(use_legacy_sql=False)
if client._default_query_job_config:
if job_config:
config = job_config._fill_from_default(client._default_query_job_config)
else:
config = copy.deepcopy(client._default_query_job_config)
else:
config = job_config or job.QueryJobConfig(use_legacy_sql=False)

config.query_parameters = query_parameters
self._query_job = client.query(
formatted_operation, job_config=config, job_id=job_id
)

if self._query_job.dry_run:
self.rowcount = 1
self.description = None
return

# Wait for the query to finish.
try:
self._query_job.result()
@@ -207,6 +221,12 @@ def _try_fetch(self, size=None):
"No query results: execute() must be called before fetch."
)

if self._query_job.dry_run:
estimated_bytes = self._query_job.total_bytes_processed
row = table.Row((estimated_bytes,), {"estimated_bytes": 0})
self._query_data = iter([row])
return

is_dml = (
self._query_job.statement_type
and self._query_job.statement_type.upper() != "SELECT"
@@ -290,6 +310,11 @@ def _bqstorage_fetch(self, bqstorage_client):
def fetchone(self):
"""Fetch a single row from the results of the last ``execute*()`` call.
.. note::
If a dry run query was executed, a row with a single value is
returned representing the estimated number of bytes that would be
processed by the query.
Returns:
Tuple:
A tuple representing a row or ``None`` if no more data is
@@ -307,6 +332,11 @@ def fetchone(self):
def fetchmany(self, size=None):
"""Fetch multiple results from the last ``execute*()`` call.
.. note::
If a dry run query was executed, a row with a single value is
returned representing the estimated number of bytes that would be
processed by the query.
.. note::
The size parameter is not used for the request/response size.
Set the ``arraysize`` attribute before calling ``execute()`` to
@@ -343,6 +373,11 @@ def fetchmany(self, size=None):
def fetchall(self):
"""Fetch all remaining results from the last ``execute*()`` call.
.. note::
If a dry run query was executed, a row with a single value is
returned representing the estimated number of bytes that would be
processed by the query.
Returns:
List[Tuple]: A list of all the rows in the results.
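A note on the Row built in _try_fetch above: table.Row takes a tuple of values plus a mapping of field name to positional index, so the synthetic dry run row is addressable by index or attribute like any other result row. A quick illustration (the byte count is made up):

    from google.cloud.bigquery import table

    # The second argument maps field name -> index into the values tuple.
    row = table.Row((12345,), {"estimated_bytes": 0})
    assert row[0] == 12345
    assert row.estimated_bytes == 12345
    assert row.values() == (12345,)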
18 changes: 18 additions & 0 deletions tests/system.py
@@ -1782,6 +1782,24 @@ def test_dbapi_fetch_w_bqstorage_client_large_result_set(self):
]
self.assertEqual(fetched_data, expected_data)

def test_dbapi_dry_run_query(self):
from google.cloud.bigquery.job import QueryJobConfig

query = """
SELECT country_name
FROM `bigquery-public-data.utility_us.country_code_iso`
WHERE country_name LIKE 'U%'
"""

Config.CURSOR.execute(query, job_config=QueryJobConfig(dry_run=True))
self.assertEqual(Config.CURSOR.rowcount, 1, "expected a single row")

rows = Config.CURSOR.fetchall()

row_tuples = [r.values() for r in rows]
expected = [(3473,)]
self.assertEqual(row_tuples, expected)

def _load_table_for_dml(self, rows, dataset_id, table_id):
from google.cloud._testing import _NamedTemporaryFile
from google.cloud.bigquery.job import CreateDisposition
122 changes: 113 additions & 9 deletions tests/unit/test_dbapi_cursor.py
@@ -36,7 +36,15 @@ def _get_target_class():
def _make_one(self, *args, **kw):
return self._get_target_class()(*args, **kw)

def _mock_client(self, rows=None, schema=None, num_dml_affected_rows=None):
def _mock_client(
self,
rows=None,
schema=None,
num_dml_affected_rows=None,
default_query_job_config=None,
dry_run_job=False,
total_bytes_processed=0,
):
from google.cloud.bigquery import client

if rows is None:
@@ -49,8 +57,12 @@ def _mock_client(self, rows=None, schema=None, num_dml_affected_rows=None):
total_rows=total_rows,
schema=schema,
num_dml_affected_rows=num_dml_affected_rows,
dry_run=dry_run_job,
total_bytes_processed=total_bytes_processed,
)
mock_client.list_rows.return_value = rows
mock_client._default_query_job_config = default_query_job_config

return mock_client

def _mock_bqstorage_client(self, rows=None, stream_count=0):
@@ -76,18 +88,31 @@ def _mock_bqstorage_client(self, rows=None, stream_count=0):

return mock_client

def _mock_job(self, total_rows=0, schema=None, num_dml_affected_rows=None):
def _mock_job(
self,
total_rows=0,
schema=None,
num_dml_affected_rows=None,
dry_run=False,
total_bytes_processed=0,
):
from google.cloud.bigquery import job

mock_job = mock.create_autospec(job.QueryJob)
mock_job.error_result = None
mock_job.state = "DONE"
mock_job.result.return_value = mock_job
mock_job._query_results = self._mock_results(
total_rows=total_rows,
schema=schema,
num_dml_affected_rows=num_dml_affected_rows,
)
mock_job.dry_run = dry_run

if dry_run:
mock_job.result.side_effect = exceptions.NotFound
mock_job.total_bytes_processed = total_bytes_processed
else:
mock_job.result.return_value = mock_job
mock_job._query_results = self._mock_results(
total_rows=total_rows,
schema=schema,
num_dml_affected_rows=num_dml_affected_rows,
)

if num_dml_affected_rows is None:
mock_job.statement_type = None # API sends back None for SELECT
@@ -373,7 +398,27 @@ def test_execute_custom_job_id(self):
self.assertEqual(args[0], "SELECT 1;")
self.assertEqual(kwargs["job_id"], "foo")

def test_execute_custom_job_config(self):
def test_execute_w_default_config(self):
from google.cloud.bigquery.dbapi import connect
from google.cloud.bigquery import job

default_config = job.QueryJobConfig(use_legacy_sql=False, flatten_results=True)
client = self._mock_client(
rows=[], num_dml_affected_rows=0, default_query_job_config=default_config
)
connection = connect(client)
cursor = connection.cursor()

cursor.execute("SELECT 1;", job_id="foo")

_, kwargs = client.query.call_args
used_config = kwargs["job_config"]
expected_config = job.QueryJobConfig(
use_legacy_sql=False, flatten_results=True, query_parameters=[]
)
self.assertEqual(used_config._properties, expected_config._properties)

def test_execute_custom_job_config_wo_default_config(self):
from google.cloud.bigquery.dbapi import connect
from google.cloud.bigquery import job

@@ -387,6 +432,29 @@ def test_execute_custom_job_config(self):
self.assertEqual(kwargs["job_id"], "foo")
self.assertEqual(kwargs["job_config"], config)

def test_execute_custom_job_config_w_default_config(self):
from google.cloud.bigquery.dbapi import connect
from google.cloud.bigquery import job

default_config = job.QueryJobConfig(use_legacy_sql=False, flatten_results=True)
client = self._mock_client(
rows=[], num_dml_affected_rows=0, default_query_job_config=default_config
)
connection = connect(client)
cursor = connection.cursor()
config = job.QueryJobConfig(use_legacy_sql=True)

cursor.execute("SELECT 1;", job_id="foo", job_config=config)

_, kwargs = client.query.call_args
used_config = kwargs["job_config"]
expected_config = job.QueryJobConfig(
use_legacy_sql=True, # the config passed to execute() prevails
flatten_results=True, # from the default
query_parameters=[],
)
self.assertEqual(used_config._properties, expected_config._properties)

def test_execute_w_dml(self):
from google.cloud.bigquery.dbapi import connect

@@ -442,6 +510,40 @@ def test_execute_w_query(self):
row = cursor.fetchone()
self.assertIsNone(row)

def test_execute_w_query_dry_run(self):
from google.cloud.bigquery.job import QueryJobConfig
from google.cloud.bigquery.schema import SchemaField
from google.cloud.bigquery import dbapi

connection = dbapi.connect(
self._mock_client(
rows=[("hello", "world", 1), ("howdy", "y'all", 2)],
schema=[
SchemaField("a", "STRING", mode="NULLABLE"),
SchemaField("b", "STRING", mode="REQUIRED"),
SchemaField("c", "INTEGER", mode="NULLABLE"),
],
dry_run_job=True,
total_bytes_processed=12345,
)
)
cursor = connection.cursor()
cursor.execute(
"SELECT a, b, c FROM hello_world WHERE d > 3;",
job_config=QueryJobConfig(dry_run=True),
)

self.assertIsNone(cursor.description)
self.assertEqual(cursor.rowcount, 1)

rows = cursor.fetchall()

# We expect a single row with one column - the estimated number of bytes
# that will be processed by the query.
self.assertEqual(len(rows), 1)
self.assertEqual(len(rows[0]), 1)
self.assertEqual(rows[0][0], 12345)

def test_execute_raises_if_result_raises(self):
import google.cloud.exceptions

@@ -451,8 +553,10 @@ def test_execute_raises_if_result_raises(self):
from google.cloud.bigquery.dbapi import exceptions

job = mock.create_autospec(job.QueryJob)
job.dry_run = None
job.result.side_effect = google.cloud.exceptions.GoogleCloudError("")
client = mock.create_autospec(client.Client)
client._default_query_job_config = None
client.query.return_value = job
connection = connect(client)
cursor = connection.cursor()

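The precedence that the two job-config unit tests pin down can be summarized in isolation: QueryJobConfig._fill_from_default (the internal helper execute() now calls) returns a merged copy in which properties set on the explicit config prevail and unset ones fall back to the client's default config. A small sketch of that merge, using the same settings as the tests:

    from google.cloud.bigquery import job

    default = job.QueryJobConfig(use_legacy_sql=False, flatten_results=True)
    explicit = job.QueryJobConfig(use_legacy_sql=True)

    # Explicit settings prevail; gaps are filled from the default config.
    merged = explicit._fill_from_default(default)
    assert merged.use_legacy_sql is True    # from the explicit config
    assert merged.flatten_results is True   # inherited from the default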