Commit bc33a67
fix: dry run queries with DB API cursor (#128)
* fix: dry run queries with DB API cursor

* Fix merge errors with master

* Return no rows on dry run instead of processed bytes count
plamut committed Jun 23, 2020
1 parent 3235255 commit bc33a67
Showing 3 changed files with 164 additions and 25 deletions.
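
For context, the fixed behavior from the caller's side looks roughly like this. A minimal sketch (client setup is assumed, the query is a placeholder, and ``_query_job`` is a private cursor attribute, used here only to show where the dry-run cost estimate still lives):

    from google.cloud import bigquery
    from google.cloud.bigquery import dbapi
    from google.cloud.bigquery.job import QueryJobConfig

    client = bigquery.Client()
    cursor = dbapi.connect(client).cursor()

    # Dry run: the query is validated and costed, but never executed.
    cursor.execute(
        "SELECT country_name FROM `bigquery-public-data.utility_us.country_code_iso`",
        job_config=QueryJobConfig(dry_run=True),
    )

    # With this fix the cursor reports no rows instead of returning a
    # processed-bytes row (as an earlier revision of this PR did).
    assert cursor.rowcount == 0
    assert cursor.fetchall() == []

    # The cost estimate is still available on the underlying job.
    # NOTE: ``_query_job`` is private API, shown for illustration only.
    print(cursor._query_job.total_bytes_processed)
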
50 changes: 37 additions & 13 deletions google/cloud/bigquery/dbapi/cursor.py
@@ -15,6 +15,7 @@
 """Cursor for the Google BigQuery DB-API."""
 
 import collections
+import copy
 import warnings
 
 try:
@@ -93,18 +94,16 @@ def _set_description(self, schema):
             return
 
         self.description = tuple(
-            [
-                Column(
-                    name=field.name,
-                    type_code=field.field_type,
-                    display_size=None,
-                    internal_size=None,
-                    precision=None,
-                    scale=None,
-                    null_ok=field.is_nullable,
-                )
-                for field in schema
-            ]
+            Column(
+                name=field.name,
+                type_code=field.field_type,
+                display_size=None,
+                internal_size=None,
+                precision=None,
+                scale=None,
+                null_ok=field.is_nullable,
+            )
+            for field in schema
         )
 
     def _set_rowcount(self, query_results):
@@ -173,12 +172,24 @@ def execute(self, operation, parameters=None, job_id=None, job_config=None):
         formatted_operation = _format_operation(operation, parameters=parameters)
         query_parameters = _helpers.to_query_parameters(parameters)
 
-        config = job_config or job.QueryJobConfig(use_legacy_sql=False)
+        if client._default_query_job_config:
+            if job_config:
+                config = job_config._fill_from_default(client._default_query_job_config)
+            else:
+                config = copy.deepcopy(client._default_query_job_config)
+        else:
+            config = job_config or job.QueryJobConfig(use_legacy_sql=False)
+
         config.query_parameters = query_parameters
         self._query_job = client.query(
             formatted_operation, job_config=config, job_id=job_id
         )
 
+        if self._query_job.dry_run:
+            self._set_description(schema=None)
+            self.rowcount = 0
+            return
+
         # Wait for the query to finish.
         try:
             self._query_job.result()
@@ -211,6 +222,10 @@ def _try_fetch(self, size=None):
                 "No query results: execute() must be called before fetch."
             )
 
+        if self._query_job.dry_run:
+            self._query_data = iter([])
+            return
+
         is_dml = (
             self._query_job.statement_type
             and self._query_job.statement_type.upper() != "SELECT"
@@ -307,6 +322,9 @@ def _bqstorage_fetch(self, bqstorage_client):
     def fetchone(self):
         """Fetch a single row from the results of the last ``execute*()`` call.
 
+        .. note::
+            If a dry run query was executed, no rows are returned.
+
         Returns:
             Tuple:
                 A tuple representing a row or ``None`` if no more data is
@@ -324,6 +342,9 @@ def fetchone(self):
     def fetchmany(self, size=None):
         """Fetch multiple results from the last ``execute*()`` call.
 
+        .. note::
+            If a dry run query was executed, no rows are returned.
+
         .. note::
             The size parameter is not used for the request/response size.
             Set the ``arraysize`` attribute before calling ``execute()`` to
@@ -360,6 +381,9 @@ def fetchmany(self, size=None):
     def fetchall(self):
         """Fetch all remaining results from the last ``execute*()`` call.
 
+        .. note::
+            If a dry run query was executed, no rows are returned.
+
         Returns:
             List[Tuple]: A list of all the rows in the results.
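
The new branch in ``execute()`` relies on ``QueryJobConfig._fill_from_default``, a private helper that overlays an explicit config onto the client's default: properties set on the explicit config win, everything else is inherited. A rough sketch of that precedence rule, mirroring the unit tests further down (not the library's internal implementation):

    from google.cloud.bigquery import job

    default = job.QueryJobConfig(use_legacy_sql=False, flatten_results=True)
    explicit = job.QueryJobConfig(use_legacy_sql=True)

    # _fill_from_default is private API: it copies the default config and
    # overlays every property explicitly set on ``explicit``.
    merged = explicit._fill_from_default(default)

    assert merged.use_legacy_sql is True     # the explicit config prevails
    assert merged.flatten_results is True    # inherited from the default
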
16 changes: 16 additions & 0 deletions tests/system.py
@@ -1782,6 +1782,22 @@ def test_dbapi_fetch_w_bqstorage_client_v1beta1_large_result_set(self):
         ]
         self.assertEqual(fetched_data, expected_data)
 
+    def test_dbapi_dry_run_query(self):
+        from google.cloud.bigquery.job import QueryJobConfig
+
+        query = """
+            SELECT country_name
+            FROM `bigquery-public-data.utility_us.country_code_iso`
+            WHERE country_name LIKE 'U%'
+        """
+
+        Config.CURSOR.execute(query, job_config=QueryJobConfig(dry_run=True))
+        self.assertEqual(Config.CURSOR.rowcount, 0, "expected no rows")
+
+        rows = Config.CURSOR.fetchall()
+
+        self.assertEqual(list(rows), [])
+
     @unittest.skipIf(
         bigquery_storage_v1 is None, "Requires `google-cloud-bigquery-storage`"
     )
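
Per the docstring notes added above, all three fetch methods come back empty after a dry run. A small sketch of that DB-API surface (assuming a default client; the query text is a placeholder):

    from google.cloud import bigquery
    from google.cloud.bigquery import dbapi
    from google.cloud.bigquery.job import QueryJobConfig

    cursor = dbapi.connect(bigquery.Client()).cursor()
    cursor.execute("SELECT 1", job_config=QueryJobConfig(dry_run=True))

    assert cursor.fetchone() is None    # no first row
    assert cursor.fetchmany(10) == []   # no batch either
    assert cursor.fetchall() == []      # nothing remaining
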
123 changes: 111 additions & 12 deletions tests/unit/test_dbapi_cursor.py
@@ -46,7 +46,15 @@ def _get_target_class():
     def _make_one(self, *args, **kw):
         return self._get_target_class()(*args, **kw)
 
-    def _mock_client(self, rows=None, schema=None, num_dml_affected_rows=None):
+    def _mock_client(
+        self,
+        rows=None,
+        schema=None,
+        num_dml_affected_rows=None,
+        default_query_job_config=None,
+        dry_run_job=False,
+        total_bytes_processed=0,
+    ):
         from google.cloud.bigquery import client
 
         if rows is None:
@@ -59,8 +67,11 @@ def _mock_client(self, rows=None, schema=None, num_dml_affected_rows=None):
             total_rows=total_rows,
             schema=schema,
             num_dml_affected_rows=num_dml_affected_rows,
+            dry_run=dry_run_job,
+            total_bytes_processed=total_bytes_processed,
         )
         mock_client.list_rows.return_value = rows
+        mock_client._default_query_job_config = default_query_job_config
 
         # Assure that the REST client gets used, not the BQ Storage client.
         mock_client._create_bqstorage_client.return_value = None
@@ -95,27 +106,41 @@ def _mock_bqstorage_client(self, rows=None, stream_count=0, v1beta1=False):
         )
 
         mock_client.create_read_session.return_value = mock_read_session
+
         mock_rows_stream = mock.MagicMock()
         mock_rows_stream.rows.return_value = iter(rows)
         mock_client.read_rows.return_value = mock_rows_stream
 
         return mock_client
 
-    def _mock_job(self, total_rows=0, schema=None, num_dml_affected_rows=None):
+    def _mock_job(
+        self,
+        total_rows=0,
+        schema=None,
+        num_dml_affected_rows=None,
+        dry_run=False,
+        total_bytes_processed=0,
+    ):
         from google.cloud.bigquery import job
 
         mock_job = mock.create_autospec(job.QueryJob)
         mock_job.error_result = None
         mock_job.state = "DONE"
-        mock_job.result.return_value = mock_job
-        mock_job._query_results = self._mock_results(
-            total_rows=total_rows,
-            schema=schema,
-            num_dml_affected_rows=num_dml_affected_rows,
-        )
-        mock_job.destination.to_bqstorage.return_value = (
-            "projects/P/datasets/DS/tables/T"
-        )
+        mock_job.dry_run = dry_run
+
+        if dry_run:
+            mock_job.result.side_effect = exceptions.NotFound
+            mock_job.total_bytes_processed = total_bytes_processed
+        else:
+            mock_job.result.return_value = mock_job
+            mock_job._query_results = self._mock_results(
+                total_rows=total_rows,
+                schema=schema,
+                num_dml_affected_rows=num_dml_affected_rows,
+            )
+            mock_job.destination.to_bqstorage.return_value = (
+                "projects/P/datasets/DS/tables/T"
+            )
 
         if num_dml_affected_rows is None:
             mock_job.statement_type = None  # API sends back None for SELECT
@@ -445,7 +470,27 @@ def test_execute_custom_job_id(self):
         self.assertEqual(args[0], "SELECT 1;")
         self.assertEqual(kwargs["job_id"], "foo")
 
-    def test_execute_custom_job_config(self):
+    def test_execute_w_default_config(self):
+        from google.cloud.bigquery.dbapi import connect
+        from google.cloud.bigquery import job
+
+        default_config = job.QueryJobConfig(use_legacy_sql=False, flatten_results=True)
+        client = self._mock_client(
+            rows=[], num_dml_affected_rows=0, default_query_job_config=default_config
+        )
+        connection = connect(client)
+        cursor = connection.cursor()
+
+        cursor.execute("SELECT 1;", job_id="foo")
+
+        _, kwargs = client.query.call_args
+        used_config = kwargs["job_config"]
+        expected_config = job.QueryJobConfig(
+            use_legacy_sql=False, flatten_results=True, query_parameters=[]
+        )
+        self.assertEqual(used_config._properties, expected_config._properties)
+
+    def test_execute_custom_job_config_wo_default_config(self):
         from google.cloud.bigquery.dbapi import connect
         from google.cloud.bigquery import job
 
@@ -459,6 +504,29 @@ def test_execute_custom_job_config(self):
         self.assertEqual(kwargs["job_id"], "foo")
         self.assertEqual(kwargs["job_config"], config)
 
+    def test_execute_custom_job_config_w_default_config(self):
+        from google.cloud.bigquery.dbapi import connect
+        from google.cloud.bigquery import job
+
+        default_config = job.QueryJobConfig(use_legacy_sql=False, flatten_results=True)
+        client = self._mock_client(
+            rows=[], num_dml_affected_rows=0, default_query_job_config=default_config
+        )
+        connection = connect(client)
+        cursor = connection.cursor()
+        config = job.QueryJobConfig(use_legacy_sql=True)
+
+        cursor.execute("SELECT 1;", job_id="foo", job_config=config)
+
+        _, kwargs = client.query.call_args
+        used_config = kwargs["job_config"]
+        expected_config = job.QueryJobConfig(
+            use_legacy_sql=True,  # the config passed to execute() prevails
+            flatten_results=True,  # from the default
+            query_parameters=[],
+        )
+        self.assertEqual(used_config._properties, expected_config._properties)
+
     def test_execute_w_dml(self):
         from google.cloud.bigquery.dbapi import connect
 
@@ -514,6 +582,35 @@ def test_execute_w_query(self):
         row = cursor.fetchone()
         self.assertIsNone(row)
 
+    def test_execute_w_query_dry_run(self):
+        from google.cloud.bigquery.job import QueryJobConfig
+        from google.cloud.bigquery.schema import SchemaField
+        from google.cloud.bigquery import dbapi
+
+        connection = dbapi.connect(
+            self._mock_client(
+                rows=[("hello", "world", 1), ("howdy", "y'all", 2)],
+                schema=[
+                    SchemaField("a", "STRING", mode="NULLABLE"),
+                    SchemaField("b", "STRING", mode="REQUIRED"),
+                    SchemaField("c", "INTEGER", mode="NULLABLE"),
+                ],
+                dry_run_job=True,
+                total_bytes_processed=12345,
+            )
+        )
+        cursor = connection.cursor()
+
+        cursor.execute(
+            "SELECT a, b, c FROM hello_world WHERE d > 3;",
+            job_config=QueryJobConfig(dry_run=True),
+        )
+
+        self.assertEqual(cursor.rowcount, 0)
+        self.assertIsNone(cursor.description)
+        rows = cursor.fetchall()
+        self.assertEqual(list(rows), [])
+
     def test_execute_raises_if_result_raises(self):
         import google.cloud.exceptions
 
@@ -523,8 +620,10 @@ def test_execute_raises_if_result_raises(self):
        from google.cloud.bigquery.dbapi import exceptions
 
        job = mock.create_autospec(job.QueryJob)
+       job.dry_run = None
        job.result.side_effect = google.cloud.exceptions.GoogleCloudError("")
        client = mock.create_autospec(client.Client)
+       client._default_query_job_config = None
        client.query.return_value = job
        connection = connect(client)
        cursor = connection.cursor()
