fix: dry run queries with DB API cursor #128

Merged · 6 commits · Jun 23, 2020
50 changes: 37 additions & 13 deletions google/cloud/bigquery/dbapi/cursor.py
@@ -15,6 +15,7 @@
 """Cursor for the Google BigQuery DB-API."""
 
 import collections
+import copy
 import warnings
 
 try:
@@ -93,18 +94,16 @@ def _set_description(self, schema):
             return
 
         self.description = tuple(
-            [
-                Column(
-                    name=field.name,
-                    type_code=field.field_type,
-                    display_size=None,
-                    internal_size=None,
-                    precision=None,
-                    scale=None,
-                    null_ok=field.is_nullable,
-                )
-                for field in schema
-            ]
+            Column(
+                name=field.name,
+                type_code=field.field_type,
+                display_size=None,
+                internal_size=None,
+                precision=None,
+                scale=None,
+                null_ok=field.is_nullable,
+            )
+            for field in schema
         )
 
     def _set_rowcount(self, query_results):
@@ -173,12 +172,24 @@ def execute(self, operation, parameters=None, job_id=None, job_config=None):
         formatted_operation = _format_operation(operation, parameters=parameters)
         query_parameters = _helpers.to_query_parameters(parameters)
 
-        config = job_config or job.QueryJobConfig(use_legacy_sql=False)
+        if client._default_query_job_config:
+            if job_config:
+                config = job_config._fill_from_default(client._default_query_job_config)
+            else:
+                config = copy.deepcopy(client._default_query_job_config)
+        else:
+            config = job_config or job.QueryJobConfig(use_legacy_sql=False)
+
         config.query_parameters = query_parameters
         self._query_job = client.query(
             formatted_operation, job_config=config, job_id=job_id
         )
 
+        if self._query_job.dry_run:
+            self._set_description(schema=None)
+            self.rowcount = 0
+            return
+
         # Wait for the query to finish.
         try:
             self._query_job.result()
@@ -211,6 +222,10 @@ def _try_fetch(self, size=None):
                 "No query results: execute() must be called before fetch."
             )
 
+        if self._query_job.dry_run:
+            self._query_data = iter([])
+            return
+
         is_dml = (
             self._query_job.statement_type
             and self._query_job.statement_type.upper() != "SELECT"
@@ -307,6 +322,9 @@ def _bqstorage_fetch(self, bqstorage_client):
     def fetchone(self):
         """Fetch a single row from the results of the last ``execute*()`` call.
 
+        .. note::
+            If a dry run query was executed, no rows are returned.
+
         Returns:
             Tuple:
                 A tuple representing a row or ``None`` if no more data is
@@ -324,6 +342,9 @@ def fetchone(self):
     def fetchmany(self, size=None):
         """Fetch multiple results from the last ``execute*()`` call.
 
+        .. note::
+            If a dry run query was executed, no rows are returned.
+
         .. note::
             The size parameter is not used for the request/response size.
             Set the ``arraysize`` attribute before calling ``execute()`` to
@@ -360,6 +381,9 @@ def fetchmany(self, size=None):
     def fetchall(self):
         """Fetch all remaining results from the last ``execute*()`` call.
 
+        .. note::
+            If a dry run query was executed, no rows are returned.
+
         Returns:
             List[Tuple]: A list of all the rows in the results.
 
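For reference, a minimal sketch of what this change enables through the public DB-API surface (application default credentials assumed; the trivial query is illustrative): with dry_run=True, execute() now returns immediately instead of failing later in fetch*().

from google.cloud import bigquery
from google.cloud.bigquery import dbapi
from google.cloud.bigquery.job import QueryJobConfig

# A dry run validates the query and estimates its cost without running it.
connection = dbapi.connect(bigquery.Client())
cursor = connection.cursor()
cursor.execute("SELECT 1;", job_config=QueryJobConfig(dry_run=True))

assert cursor.rowcount == 0     # set by execute() for dry run jobs
assert cursor.fetchall() == []  # fetch*() yields no rows for dry runs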
16 changes: 16 additions & 0 deletions tests/system.py
@@ -1782,6 +1782,22 @@ def test_dbapi_fetch_w_bqstorage_client_v1beta1_large_result_set(self):
         ]
         self.assertEqual(fetched_data, expected_data)
 
+    def test_dbapi_dry_run_query(self):
+        from google.cloud.bigquery.job import QueryJobConfig
+
+        query = """
+            SELECT country_name
+            FROM `bigquery-public-data.utility_us.country_code_iso`
+            WHERE country_name LIKE 'U%'
+        """
+
+        Config.CURSOR.execute(query, job_config=QueryJobConfig(dry_run=True))
+        self.assertEqual(Config.CURSOR.rowcount, 0, "expected no rows")
+
+        rows = Config.CURSOR.fetchall()
+
+        self.assertEqual(list(rows), [])
+
     @unittest.skipIf(
         bigquery_storage_v1 is None, "Requires `google-cloud-bigquery-storage`"
     )
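The system test checks only that a dry run returns no rows; in practice the point of a dry run is the byte estimate. A sketch of retrieving it, with a caveat: _query_job is a private cursor attribute (this PR adds no public accessor), and total_bytes_processed is the QueryJob property the unit tests below mock.

from google.cloud import bigquery
from google.cloud.bigquery import dbapi
from google.cloud.bigquery.job import QueryJobConfig

cursor = dbapi.connect(bigquery.Client()).cursor()
cursor.execute(
    "SELECT country_name FROM `bigquery-public-data.utility_us.country_code_iso`",
    job_config=QueryJobConfig(dry_run=True),
)

# _query_job is private API; used here only for illustration.
print("Estimated bytes processed:", cursor._query_job.total_bytes_processed)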
123 changes: 111 additions & 12 deletions tests/unit/test_dbapi_cursor.py
@@ -46,7 +46,15 @@ def _get_target_class():
     def _make_one(self, *args, **kw):
         return self._get_target_class()(*args, **kw)
 
-    def _mock_client(self, rows=None, schema=None, num_dml_affected_rows=None):
+    def _mock_client(
+        self,
+        rows=None,
+        schema=None,
+        num_dml_affected_rows=None,
+        default_query_job_config=None,
+        dry_run_job=False,
+        total_bytes_processed=0,
+    ):
         from google.cloud.bigquery import client
 
         if rows is None:
@@ -59,8 +67,11 @@ def _mock_client(self, rows=None, schema=None, num_dml_affected_rows=None):
             total_rows=total_rows,
             schema=schema,
             num_dml_affected_rows=num_dml_affected_rows,
+            dry_run=dry_run_job,
+            total_bytes_processed=total_bytes_processed,
         )
         mock_client.list_rows.return_value = rows
+        mock_client._default_query_job_config = default_query_job_config
 
         # Assure that the REST client gets used, not the BQ Storage client.
         mock_client._create_bqstorage_client.return_value = None
@@ -95,27 +106,41 @@ def _mock_bqstorage_client(self, rows=None, stream_count=0, v1beta1=False):
         )
 
         mock_client.create_read_session.return_value = mock_read_session
+
         mock_rows_stream = mock.MagicMock()
         mock_rows_stream.rows.return_value = iter(rows)
         mock_client.read_rows.return_value = mock_rows_stream
 
         return mock_client
 
-    def _mock_job(self, total_rows=0, schema=None, num_dml_affected_rows=None):
+    def _mock_job(
+        self,
+        total_rows=0,
+        schema=None,
+        num_dml_affected_rows=None,
+        dry_run=False,
+        total_bytes_processed=0,
+    ):
         from google.cloud.bigquery import job
 
         mock_job = mock.create_autospec(job.QueryJob)
         mock_job.error_result = None
         mock_job.state = "DONE"
-        mock_job.result.return_value = mock_job
-        mock_job._query_results = self._mock_results(
-            total_rows=total_rows,
-            schema=schema,
-            num_dml_affected_rows=num_dml_affected_rows,
-        )
-        mock_job.destination.to_bqstorage.return_value = (
-            "projects/P/datasets/DS/tables/T"
-        )
+        mock_job.dry_run = dry_run
+
+        if dry_run:
+            mock_job.result.side_effect = exceptions.NotFound
+            mock_job.total_bytes_processed = total_bytes_processed
+        else:
+            mock_job.result.return_value = mock_job
+            mock_job._query_results = self._mock_results(
+                total_rows=total_rows,
+                schema=schema,
+                num_dml_affected_rows=num_dml_affected_rows,
+            )
+            mock_job.destination.to_bqstorage.return_value = (
+                "projects/P/datasets/DS/tables/T"
+            )
 
         if num_dml_affected_rows is None:
             mock_job.statement_type = None  # API sends back None for SELECT
@@ -445,7 +470,27 @@ def test_execute_custom_job_id(self):
         self.assertEqual(args[0], "SELECT 1;")
         self.assertEqual(kwargs["job_id"], "foo")
 
-    def test_execute_custom_job_config(self):
+    def test_execute_w_default_config(self):
+        from google.cloud.bigquery.dbapi import connect
+        from google.cloud.bigquery import job
+
+        default_config = job.QueryJobConfig(use_legacy_sql=False, flatten_results=True)
+        client = self._mock_client(
+            rows=[], num_dml_affected_rows=0, default_query_job_config=default_config
+        )
+        connection = connect(client)
+        cursor = connection.cursor()
+
+        cursor.execute("SELECT 1;", job_id="foo")
+
+        _, kwargs = client.query.call_args
+        used_config = kwargs["job_config"]
+        expected_config = job.QueryJobConfig(
+            use_legacy_sql=False, flatten_results=True, query_parameters=[]
+        )
+        self.assertEqual(used_config._properties, expected_config._properties)
+
+    def test_execute_custom_job_config_wo_default_config(self):
         from google.cloud.bigquery.dbapi import connect
         from google.cloud.bigquery import job
 
@@ -459,6 +504,29 @@ def test_execute_custom_job_config(self):
         self.assertEqual(kwargs["job_id"], "foo")
         self.assertEqual(kwargs["job_config"], config)
 
+    def test_execute_custom_job_config_w_default_config(self):
+        from google.cloud.bigquery.dbapi import connect
+        from google.cloud.bigquery import job
+
+        default_config = job.QueryJobConfig(use_legacy_sql=False, flatten_results=True)
+        client = self._mock_client(
+            rows=[], num_dml_affected_rows=0, default_query_job_config=default_config
+        )
+        connection = connect(client)
+        cursor = connection.cursor()
+        config = job.QueryJobConfig(use_legacy_sql=True)
+
+        cursor.execute("SELECT 1;", job_id="foo", job_config=config)
+
+        _, kwargs = client.query.call_args
+        used_config = kwargs["job_config"]
+        expected_config = job.QueryJobConfig(
+            use_legacy_sql=True,  # the config passed to execute() prevails
+            flatten_results=True,  # from the default
+            query_parameters=[],
+        )
+        self.assertEqual(used_config._properties, expected_config._properties)
+
     def test_execute_w_dml(self):
         from google.cloud.bigquery.dbapi import connect
 
@@ -514,6 +582,35 @@ def test_execute_w_query(self):
         row = cursor.fetchone()
         self.assertIsNone(row)
 
+    def test_execute_w_query_dry_run(self):
+        from google.cloud.bigquery.job import QueryJobConfig
+        from google.cloud.bigquery.schema import SchemaField
+        from google.cloud.bigquery import dbapi
+
+        connection = dbapi.connect(
+            self._mock_client(
+                rows=[("hello", "world", 1), ("howdy", "y'all", 2)],
+                schema=[
+                    SchemaField("a", "STRING", mode="NULLABLE"),
+                    SchemaField("b", "STRING", mode="REQUIRED"),
+                    SchemaField("c", "INTEGER", mode="NULLABLE"),
+                ],
+                dry_run_job=True,
+                total_bytes_processed=12345,
+            )
+        )
+        cursor = connection.cursor()
+
+        cursor.execute(
+            "SELECT a, b, c FROM hello_world WHERE d > 3;",
+            job_config=QueryJobConfig(dry_run=True),
+        )
+
+        self.assertEqual(cursor.rowcount, 0)
+        self.assertIsNone(cursor.description)
+        rows = cursor.fetchall()
+        self.assertEqual(list(rows), [])
+
     def test_execute_raises_if_result_raises(self):
         import google.cloud.exceptions
 
@@ -523,8 +620,10 @@ def test_execute_raises_if_result_raises(self):
         from google.cloud.bigquery.dbapi import exceptions
 
         job = mock.create_autospec(job.QueryJob)
+        job.dry_run = None
        job.result.side_effect = google.cloud.exceptions.GoogleCloudError("")
         client = mock.create_autospec(client.Client)
+        client._default_query_job_config = None
         client.query.return_value = job
         connection = connect(client)
         cursor = connection.cursor()
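The two *_w_default_config tests encode the merge rule introduced in execute(): options set on the config passed to execute() prevail, while options set only on the client's default_query_job_config fill in the gaps. A sketch of that interaction (credentials assumed):

from google.cloud import bigquery
from google.cloud.bigquery import dbapi
from google.cloud.bigquery.job import QueryJobConfig

default = QueryJobConfig(use_legacy_sql=False, flatten_results=True)
client = bigquery.Client(default_query_job_config=default)
cursor = dbapi.connect(client).cursor()

# Runs with use_legacy_sql=True (the per-call config prevails) and
# flatten_results=True (filled in from the client's default).
cursor.execute("SELECT 1;", job_config=QueryJobConfig(use_legacy_sql=True))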