fix: dry run queries with DB API cursor #128

Merged
merged 6 commits into from Jun 23, 2020
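For context, this change makes dry-run queries work through the DB-API cursor: execute() no longer waits on results a dry run never produces, and the fetch methods return a single synthetic row holding the estimated bytes processed. A rough usage sketch (the client setup and sample query are assumptions, not part of this PR):

    from google.cloud import bigquery
    from google.cloud.bigquery import dbapi
    from google.cloud.bigquery.job import QueryJobConfig

    client = bigquery.Client()  # assumes default credentials and project
    cursor = dbapi.connect(client).cursor()

    # dry_run=True validates the query and estimates its cost without running it.
    cursor.execute(
        "SELECT name FROM `bigquery-public-data.usa_names.usa_1910_2013`",
        job_config=QueryJobConfig(dry_run=True),
    )

    # With this fix, the cursor reports one row with one column: the
    # estimated number of bytes the query would process.
    (estimated_bytes,) = cursor.fetchone()
    print("Estimated bytes processed:", estimated_bytes)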
Changes from 1 commit
63 changes: 50 additions & 13 deletions google/cloud/bigquery/dbapi/cursor.py
@@ -15,6 +15,7 @@
"""Cursor for the Google BigQuery DB-API."""

import collections
import copy

try:
from collections import abc as collections_abc
@@ -26,6 +27,8 @@
import six

from google.cloud.bigquery import job
from google.cloud.bigquery import schema
from google.cloud.bigquery import table
from google.cloud.bigquery.dbapi import _helpers
from google.cloud.bigquery.dbapi import exceptions
import google.cloud.exceptions
@@ -89,18 +92,16 @@ def _set_description(self, schema):
return

self.description = tuple(
[
Column(
name=field.name,
type_code=field.field_type,
display_size=None,
internal_size=None,
precision=None,
scale=None,
null_ok=field.is_nullable,
)
for field in schema
]
Column(
name=field.name,
type_code=field.field_type,
display_size=None,
internal_size=None,
precision=None,
scale=None,
null_ok=field.is_nullable,
)
for field in schema
)

def _set_rowcount(self, query_results):
@@ -169,12 +170,27 @@ def execute(self, operation, parameters=None, job_id=None, job_config=None):
formatted_operation = _format_operation(operation, parameters=parameters)
query_parameters = _helpers.to_query_parameters(parameters)

config = job_config or job.QueryJobConfig(use_legacy_sql=False)
if client._default_query_job_config:
if job_config:
config = job_config._fill_from_default(client._default_query_job_config)
else:
config = copy.deepcopy(client._default_query_job_config)
else:
config = job_config or job.QueryJobConfig(use_legacy_sql=False)

config.query_parameters = query_parameters
self._query_job = client.query(
formatted_operation, job_config=config, job_id=job_id
)

if self._query_job.dry_run:
schema_field = schema.SchemaField(
name="estimated_bytes", field_type="INTEGER", mode="REQUIRED",
)
self._set_description(schema=[schema_field])
self.rowcount = 1
return

# Wait for the query to finish.
try:
self._query_job.result()
@@ -207,6 +223,12 @@ def _try_fetch(self, size=None):
"No query results: execute() must be called before fetch."
)

if self._query_job.dry_run:
estimated_bytes = self._query_job.total_bytes_processed
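# Row(values, field_to_index): the mapping ties the "estimated_bytes" column to index 0.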
row = table.Row((estimated_bytes,), {"estimated_bytes": 0})
self._query_data = iter([row])
return

is_dml = (
self._query_job.statement_type
and self._query_job.statement_type.upper() != "SELECT"
@@ -290,6 +312,11 @@ def _bqstorage_fetch(self, bqstorage_client):
def fetchone(self):
"""Fetch a single row from the results of the last ``execute*()`` call.

.. note::
If a dry run query was executed, a row with a single value is
returned representing the estimated number of bytes that would be
processed by the query.

Returns:
Tuple:
A tuple representing a row or ``None`` if no more data is
@@ -307,6 +334,11 @@ def fetchmany(self, size=None):
def fetchmany(self, size=None):
"""Fetch multiple results from the last ``execute*()`` call.

.. note::
If a dry run query was executed, a row with a single value is
returned representing the estimated number of bytes that would be
processed by the query.

.. note::
The size parameter is not used for the request/response size.
Set the ``arraysize`` attribute before calling ``execute()`` to
@@ -343,6 +375,11 @@ def fetchall(self):
def fetchall(self):
"""Fetch all remaining results from the last ``execute*()`` call.

.. note::
If a dry run query was executed, a row with a single value is
returned representing the estimated number of bytes that would be
processed by the query.

Returns:
List[Tuple]: A list of all the rows in the results.

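A note on the config handling added to execute() above: a job_config passed to the call takes precedence, and any of its unset fields are filled in from the client's _default_query_job_config. A minimal sketch of that precedence rule, reusing the private _fill_from_default helper the diff itself calls (illustrative, not part of the change):

    from google.cloud.bigquery import job

    default = job.QueryJobConfig(use_legacy_sql=False, flatten_results=True)
    explicit = job.QueryJobConfig(use_legacy_sql=True)

    # Fields set on the explicit config win; unset fields fall back to the
    # default, so the merged config ends up with use_legacy_sql=True and
    # flatten_results=True.
    merged = explicit._fill_from_default(default)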
18 changes: 18 additions & 0 deletions tests/system.py
@@ -1782,6 +1782,24 @@ def test_dbapi_fetch_w_bqstorage_client_large_result_set(self):
]
self.assertEqual(fetched_data, expected_data)

def test_dbapi_dry_run_query(self):
from google.cloud.bigquery.job import QueryJobConfig

query = """
SELECT country_name
FROM `bigquery-public-data.utility_us.country_code_iso`
WHERE country_name LIKE 'U%'
"""

Config.CURSOR.execute(query, job_config=QueryJobConfig(dry_run=True))
self.assertEqual(Config.CURSOR.rowcount, 1, "expected a single row")

rows = Config.CURSOR.fetchall()

row_tuples = [r.values() for r in rows]
expected = [(3473,)]
self.assertEqual(row_tuples, expected)

def _load_table_for_dml(self, rows, dataset_id, table_id):
from google.cloud._testing import _NamedTemporaryFile
from google.cloud.bigquery.job import CreateDisposition
134 changes: 125 additions & 9 deletions tests/unit/test_dbapi_cursor.py
@@ -36,7 +36,15 @@ def _get_target_class():
def _make_one(self, *args, **kw):
return self._get_target_class()(*args, **kw)

def _mock_client(self, rows=None, schema=None, num_dml_affected_rows=None):
def _mock_client(
self,
rows=None,
schema=None,
num_dml_affected_rows=None,
default_query_job_config=None,
dry_run_job=False,
total_bytes_processed=0,
):
from google.cloud.bigquery import client

if rows is None:
@@ -49,8 +57,12 @@ def _mock_client(self, rows=None, schema=None, num_dml_affected_rows=None):
total_rows=total_rows,
schema=schema,
num_dml_affected_rows=num_dml_affected_rows,
dry_run=dry_run_job,
total_bytes_processed=total_bytes_processed,
)
mock_client.list_rows.return_value = rows
mock_client._default_query_job_config = default_query_job_config

return mock_client

def _mock_bqstorage_client(self, rows=None, stream_count=0):
@@ -76,18 +88,31 @@ def _mock_bqstorage_client(self, rows=None, stream_count=0):

return mock_client

def _mock_job(self, total_rows=0, schema=None, num_dml_affected_rows=None):
def _mock_job(
self,
total_rows=0,
schema=None,
num_dml_affected_rows=None,
dry_run=False,
total_bytes_processed=0,
):
from google.cloud.bigquery import job

mock_job = mock.create_autospec(job.QueryJob)
mock_job.error_result = None
mock_job.state = "DONE"
mock_job.result.return_value = mock_job
mock_job._query_results = self._mock_results(
total_rows=total_rows,
schema=schema,
num_dml_affected_rows=num_dml_affected_rows,
)
mock_job.dry_run = dry_run

if dry_run:
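# A dry run never produces fetchable results on the server, so the
# mock models result() as raising NotFound.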
mock_job.result.side_effect = exceptions.NotFound
mock_job.total_bytes_processed = total_bytes_processed
else:
mock_job.result.return_value = mock_job
mock_job._query_results = self._mock_results(
total_rows=total_rows,
schema=schema,
num_dml_affected_rows=num_dml_affected_rows,
)

if num_dml_affected_rows is None:
mock_job.statement_type = None # API sends back None for SELECT
@@ -373,7 +398,27 @@ def test_execute_custom_job_id(self):
self.assertEqual(args[0], "SELECT 1;")
self.assertEqual(kwargs["job_id"], "foo")

def test_execute_custom_job_config(self):
def test_execute_w_default_config(self):
from google.cloud.bigquery.dbapi import connect
from google.cloud.bigquery import job

default_config = job.QueryJobConfig(use_legacy_sql=False, flatten_results=True)
client = self._mock_client(
rows=[], num_dml_affected_rows=0, default_query_job_config=default_config
)
connection = connect(client)
cursor = connection.cursor()

cursor.execute("SELECT 1;", job_id="foo")

_, kwargs = client.query.call_args
used_config = kwargs["job_config"]
expected_config = job.QueryJobConfig(
use_legacy_sql=False, flatten_results=True, query_parameters=[]
)
self.assertEqual(used_config._properties, expected_config._properties)

def test_execute_custom_job_config_wo_default_config(self):
from google.cloud.bigquery.dbapi import connect
from google.cloud.bigquery import job

@@ -387,6 +432,29 @@ def test_execute_custom_job_config_wo_default_config(self):
self.assertEqual(kwargs["job_id"], "foo")
self.assertEqual(kwargs["job_config"], config)

def test_execute_custom_job_config_w_default_config(self):
from google.cloud.bigquery.dbapi import connect
from google.cloud.bigquery import job

default_config = job.QueryJobConfig(use_legacy_sql=False, flatten_results=True)
client = self._mock_client(
rows=[], num_dml_affected_rows=0, default_query_job_config=default_config
)
connection = connect(client)
cursor = connection.cursor()
config = job.QueryJobConfig(use_legacy_sql=True)

cursor.execute("SELECT 1;", job_id="foo", job_config=config)

_, kwargs = client.query.call_args
used_config = kwargs["job_config"]
expected_config = job.QueryJobConfig(
use_legacy_sql=True, # the config passed to execute() prevails
flatten_results=True, # from the default
query_parameters=[],
)
self.assertEqual(used_config._properties, expected_config._properties)

def test_execute_w_dml(self):
from google.cloud.bigquery.dbapi import connect

@@ -442,6 +510,52 @@ def test_execute_w_query(self):
row = cursor.fetchone()
self.assertIsNone(row)

def test_execute_w_query_dry_run(self):
from google.cloud.bigquery.job import QueryJobConfig
from google.cloud.bigquery.schema import SchemaField
from google.cloud.bigquery import dbapi

connection = dbapi.connect(
self._mock_client(
rows=[("hello", "world", 1), ("howdy", "y'all", 2)],
schema=[
SchemaField("a", "STRING", mode="NULLABLE"),
SchemaField("b", "STRING", mode="REQUIRED"),
SchemaField("c", "INTEGER", mode="NULLABLE"),
],
dry_run_job=True,
total_bytes_processed=12345,
)
)
cursor = connection.cursor()

cursor.execute(
"SELECT a, b, c FROM hello_world WHERE d > 3;",
job_config=QueryJobConfig(dry_run=True),
)

expected_description = (
dbapi.cursor.Column(
name="estimated_bytes",
type_code="INTEGER",
display_size=None,
internal_size=None,
precision=None,
scale=None,
null_ok=False,
),
)
self.assertEqual(cursor.description, expected_description)
self.assertEqual(cursor.rowcount, 1)

rows = cursor.fetchall()

# We expect a single row with one column - the estimated number of bytes
# that will be processed by the query.
self.assertEqual(len(rows), 1)
self.assertEqual(len(rows[0]), 1)
self.assertEqual(rows[0][0], 12345)

def test_execute_raises_if_result_raises(self):
import google.cloud.exceptions

@@ -451,8 +565,10 @@ def test_execute_raises_if_result_raises(self):
from google.cloud.bigquery.dbapi import exceptions

job = mock.create_autospec(job.QueryJob)
job.dry_run = None
job.result.side_effect = google.cloud.exceptions.GoogleCloudError("")
client = mock.create_autospec(client.Client)
client._default_query_job_config = None
client.query.return_value = job
connection = connect(client)
cursor = connection.cursor()