Skip to content

Commit

Permalink
fix: dry run queries with DB API cursor
Browse files Browse the repository at this point in the history
  • Loading branch information
plamut committed Jun 10, 2020
1 parent 3869e34 commit 6bbf59e
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 22 deletions.
63 changes: 50 additions & 13 deletions google/cloud/bigquery/dbapi/cursor.py
Expand Up @@ -15,6 +15,7 @@
"""Cursor for the Google BigQuery DB-API."""

import collections
import copy

try:
from collections import abc as collections_abc
Expand All @@ -26,6 +27,8 @@
import six

from google.cloud.bigquery import job
from google.cloud.bigquery import schema
from google.cloud.bigquery import table
from google.cloud.bigquery.dbapi import _helpers
from google.cloud.bigquery.dbapi import exceptions
import google.cloud.exceptions
Expand Down Expand Up @@ -89,18 +92,16 @@ def _set_description(self, schema):
return

self.description = tuple(
[
Column(
name=field.name,
type_code=field.field_type,
display_size=None,
internal_size=None,
precision=None,
scale=None,
null_ok=field.is_nullable,
)
for field in schema
]
Column(
name=field.name,
type_code=field.field_type,
display_size=None,
internal_size=None,
precision=None,
scale=None,
null_ok=field.is_nullable,
)
for field in schema
)

def _set_rowcount(self, query_results):
Expand Down Expand Up @@ -169,12 +170,27 @@ def execute(self, operation, parameters=None, job_id=None, job_config=None):
formatted_operation = _format_operation(operation, parameters=parameters)
query_parameters = _helpers.to_query_parameters(parameters)

config = job_config or job.QueryJobConfig(use_legacy_sql=False)
if client._default_query_job_config:
if job_config:
config = job_config._fill_from_default(client._default_query_job_config)
else:
config = copy.deepcopy(client._default_query_job_config)
else:
config = job_config or job.QueryJobConfig(use_legacy_sql=False)

config.query_parameters = query_parameters
self._query_job = client.query(
formatted_operation, job_config=config, job_id=job_id
)

if self._query_job.dry_run:
schema_field = schema.SchemaField(
name="estimated_bytes", field_type="INTEGER", mode="REQUIRED",
)
self._set_description(schema=[schema_field])
self.rowcount = 1
return

# Wait for the query to finish.
try:
self._query_job.result()
Expand Down Expand Up @@ -207,6 +223,12 @@ def _try_fetch(self, size=None):
"No query results: execute() must be called before fetch."
)

if self._query_job.dry_run:
estimated_bytes = self._query_job.total_bytes_processed
row = table.Row((estimated_bytes,), {"estimated_bytes": 0})
self._query_data = iter([row])
return

is_dml = (
self._query_job.statement_type
and self._query_job.statement_type.upper() != "SELECT"
Expand Down Expand Up @@ -290,6 +312,11 @@ def _bqstorage_fetch(self, bqstorage_client):
def fetchone(self):
"""Fetch a single row from the results of the last ``execute*()`` call.
.. note::
If a dry run query was executed, a row with a single value is
returned representing the estimated number of bytes that would be
processed by the query.
Returns:
Tuple:
A tuple representing a row or ``None`` if no more data is
Expand All @@ -307,6 +334,11 @@ def fetchone(self):
def fetchmany(self, size=None):
"""Fetch multiple results from the last ``execute*()`` call.
.. note::
If a dry run query was executed, a row with a single value is
returned representing the estimated number of bytes that would be
processed by the query.
.. note::
The size parameter is not used for the request/response size.
Set the ``arraysize`` attribute before calling ``execute()`` to
Expand Down Expand Up @@ -343,6 +375,11 @@ def fetchmany(self, size=None):
def fetchall(self):
"""Fetch all remaining results from the last ``execute*()`` call.
.. note::
If a dry run query was executed, a row with a single value is
returned representing the estimated number of bytes that would be
processed by the query.
Returns:
List[Tuple]: A list of all the rows in the results.
Expand Down
18 changes: 18 additions & 0 deletions tests/system.py
Expand Up @@ -1782,6 +1782,24 @@ def test_dbapi_fetch_w_bqstorage_client_large_result_set(self):
]
self.assertEqual(fetched_data, expected_data)

def test_dbapi_dry_run_query(self):
    """System test: executing a dry-run query through the DB-API cursor
    yields exactly one row containing the estimated bytes processed.
    """
    from google.cloud.bigquery.job import QueryJobConfig

    query = """
        SELECT country_name
        FROM `bigquery-public-data.utility_us.country_code_iso`
        WHERE country_name LIKE 'U%'
    """

    # A dry run never executes the query; the cursor synthesizes a
    # single-row result holding the byte estimate instead.
    Config.CURSOR.execute(query, job_config=QueryJobConfig(dry_run=True))
    self.assertEqual(Config.CURSOR.rowcount, 1, "expected a single row")

    rows = Config.CURSOR.fetchall()

    row_tuples = [r.values() for r in rows]
    # NOTE(review): 3473 is the byte estimate for this public table at the
    # time of writing — it is tied to the table's size, so it could drift
    # if the public dataset changes. Confirm if this test starts failing.
    expected = [(3473,)]
    self.assertEqual(row_tuples, expected)

def _load_table_for_dml(self, rows, dataset_id, table_id):
from google.cloud._testing import _NamedTemporaryFile
from google.cloud.bigquery.job import CreateDisposition
Expand Down
134 changes: 125 additions & 9 deletions tests/unit/test_dbapi_cursor.py
Expand Up @@ -36,7 +36,15 @@ def _get_target_class():
def _make_one(self, *args, **kw):
return self._get_target_class()(*args, **kw)

def _mock_client(self, rows=None, schema=None, num_dml_affected_rows=None):
def _mock_client(
self,
rows=None,
schema=None,
num_dml_affected_rows=None,
default_query_job_config=None,
dry_run_job=False,
total_bytes_processed=0,
):
from google.cloud.bigquery import client

if rows is None:
Expand All @@ -49,8 +57,12 @@ def _mock_client(self, rows=None, schema=None, num_dml_affected_rows=None):
total_rows=total_rows,
schema=schema,
num_dml_affected_rows=num_dml_affected_rows,
dry_run=dry_run_job,
total_bytes_processed=total_bytes_processed,
)
mock_client.list_rows.return_value = rows
mock_client._default_query_job_config = default_query_job_config

return mock_client

def _mock_bqstorage_client(self, rows=None, stream_count=0):
Expand All @@ -76,18 +88,31 @@ def _mock_bqstorage_client(self, rows=None, stream_count=0):

return mock_client

def _mock_job(self, total_rows=0, schema=None, num_dml_affected_rows=None):
def _mock_job(
self,
total_rows=0,
schema=None,
num_dml_affected_rows=None,
dry_run=False,
total_bytes_processed=0,
):
from google.cloud.bigquery import job

mock_job = mock.create_autospec(job.QueryJob)
mock_job.error_result = None
mock_job.state = "DONE"
mock_job.result.return_value = mock_job
mock_job._query_results = self._mock_results(
total_rows=total_rows,
schema=schema,
num_dml_affected_rows=num_dml_affected_rows,
)
mock_job.dry_run = dry_run

if dry_run:
mock_job.result.side_effect = exceptions.NotFound
mock_job.total_bytes_processed = total_bytes_processed
else:
mock_job.result.return_value = mock_job
mock_job._query_results = self._mock_results(
total_rows=total_rows,
schema=schema,
num_dml_affected_rows=num_dml_affected_rows,
)

if num_dml_affected_rows is None:
mock_job.statement_type = None # API sends back None for SELECT
Expand Down Expand Up @@ -373,7 +398,27 @@ def test_execute_custom_job_id(self):
self.assertEqual(args[0], "SELECT 1;")
self.assertEqual(kwargs["job_id"], "foo")

def test_execute_custom_job_config(self):
def test_execute_w_default_config(self):
    """When execute() receives no explicit job config, the config sent to
    ``client.query`` must be derived from the client's default query job
    config (with query parameters filled in by the cursor).
    """
    from google.cloud.bigquery import job
    from google.cloud.bigquery.dbapi import connect

    client_default = job.QueryJobConfig(use_legacy_sql=False, flatten_results=True)
    mock_client = self._mock_client(
        rows=[],
        num_dml_affected_rows=0,
        default_query_job_config=client_default,
    )
    cursor = connect(mock_client).cursor()

    cursor.execute("SELECT 1;", job_id="foo")

    # The config handed to client.query() should mirror the client default,
    # plus the (empty) query_parameters list that execute() always sets.
    sent_config = mock_client.query.call_args[1]["job_config"]
    reference_config = job.QueryJobConfig(
        use_legacy_sql=False, flatten_results=True, query_parameters=[]
    )
    self.assertEqual(sent_config._properties, reference_config._properties)

def test_execute_custom_job_config_wo_default_config(self):
from google.cloud.bigquery.dbapi import connect
from google.cloud.bigquery import job

Expand All @@ -387,6 +432,29 @@ def test_execute_custom_job_config(self):
self.assertEqual(kwargs["job_id"], "foo")
self.assertEqual(kwargs["job_config"], config)

def test_execute_custom_job_config_w_default_config(self):
    """An explicit job config passed to execute() must be merged with the
    client's default config, with the explicit config taking precedence.
    """
    from google.cloud.bigquery import job
    from google.cloud.bigquery.dbapi import connect

    client_default = job.QueryJobConfig(use_legacy_sql=False, flatten_results=True)
    mock_client = self._mock_client(
        rows=[],
        num_dml_affected_rows=0,
        default_query_job_config=client_default,
    )
    cursor = connect(mock_client).cursor()
    explicit_config = job.QueryJobConfig(use_legacy_sql=True)

    cursor.execute("SELECT 1;", job_id="foo", job_config=explicit_config)

    sent_config = mock_client.query.call_args[1]["job_config"]
    reference_config = job.QueryJobConfig(
        use_legacy_sql=True,  # the config passed to execute() prevails
        flatten_results=True,  # from the default
        query_parameters=[],
    )
    self.assertEqual(sent_config._properties, reference_config._properties)

def test_execute_w_dml(self):
from google.cloud.bigquery.dbapi import connect

Expand Down Expand Up @@ -442,6 +510,52 @@ def test_execute_w_query(self):
row = cursor.fetchone()
self.assertIsNone(row)

def test_execute_w_query_dry_run(self):
    """A dry-run execute() must expose a synthetic one-column description
    (``estimated_bytes``), report rowcount 1, and make the byte estimate
    fetchable as the single row's only value.
    """
    from google.cloud.bigquery.job import QueryJobConfig
    from google.cloud.bigquery.schema import SchemaField
    from google.cloud.bigquery import dbapi

    connection = dbapi.connect(
        self._mock_client(
            # Real rows/schema are configured but must be ignored for a
            # dry run — only total_bytes_processed should surface.
            rows=[("hello", "world", 1), ("howdy", "y'all", 2)],
            schema=[
                SchemaField("a", "STRING", mode="NULLABLE"),
                SchemaField("b", "STRING", mode="REQUIRED"),
                SchemaField("c", "INTEGER", mode="NULLABLE"),
            ],
            dry_run_job=True,
            total_bytes_processed=12345,
        )
    )
    cursor = connection.cursor()

    cursor.execute(
        "SELECT a, b, c FROM hello_world WHERE d > 3;",
        job_config=QueryJobConfig(dry_run=True),
    )

    # The description must reflect the synthetic dry-run schema, not the
    # query's real result schema.
    expected_description = (
        dbapi.cursor.Column(
            name="estimated_bytes",
            type_code="INTEGER",
            display_size=None,
            internal_size=None,
            precision=None,
            scale=None,
            null_ok=False,
        ),
    )
    self.assertEqual(cursor.description, expected_description)
    self.assertEqual(cursor.rowcount, 1)

    rows = cursor.fetchall()

    # We expect a single row with one column - the estimated number of bytes
    # that will be processed by the query.
    self.assertEqual(len(rows), 1)
    self.assertEqual(len(rows[0]), 1)
    self.assertEqual(rows[0][0], 12345)

def test_execute_raises_if_result_raises(self):
import google.cloud.exceptions

Expand All @@ -451,8 +565,10 @@ def test_execute_raises_if_result_raises(self):
from google.cloud.bigquery.dbapi import exceptions

job = mock.create_autospec(job.QueryJob)
job.dry_run = None
job.result.side_effect = google.cloud.exceptions.GoogleCloudError("")
client = mock.create_autospec(client.Client)
client._default_query_job_config = None
client.query.return_value = job
connection = connect(client)
cursor = connection.cursor()
Expand Down

0 comments on commit 6bbf59e

Please sign in to comment.