Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: accept DatasetListItem where DatasetReference is accepted #597

Merged
merged 20 commits into from Apr 12, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 7 additions & 1 deletion google/cloud/bigquery/client.py
Expand Up @@ -1321,7 +1321,13 @@ def list_tables(
)

if not isinstance(dataset, (Dataset, DatasetReference)):
raise TypeError("dataset must be a Dataset, DatasetReference, or string")
if isinstance(dataset, DatasetListItem):
dataset = dataset.reference
else:
raise TypeError(
"dataset must be a Dataset, DatasetReference, DatasetListItem,"
" or string"
)

path = "%s/tables" % dataset.path
span_attributes = {"path": path}
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/conftest.py
@@ -0,0 +1,18 @@
import pytest

from .helpers import make_client


@pytest.fixture
def client():
    """Yield a BigQuery ``Client`` backed by mock credentials.

    Built via ``helpers.make_client`` so no real auth or network is touched.
    """
    test_client = make_client()
    yield test_client


@pytest.fixture
def PROJECT():
    """Yield the placeholder project ID shared by these unit tests."""
    yield "PROJECT"


@pytest.fixture
def DS_ID():
    """Yield the placeholder dataset ID shared by these unit tests."""
    yield "DATASET_ID"
41 changes: 41 additions & 0 deletions tests/unit/helpers.py
Expand Up @@ -12,6 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import google.auth.credentials
import google.cloud.bigquery.client
import google.cloud.bigquery.dataset
import mock
import pytest


def make_connection(*responses):
import google.cloud.bigquery._http
Expand All @@ -31,3 +36,39 @@ def _to_pyarrow(value):
import pyarrow

return pyarrow.array([value])[0]


def make_client(project="PROJECT", **kw):
    """Return a BigQuery ``Client`` for *project* with mocked credentials.

    Extra keyword arguments are forwarded to the ``Client`` constructor.
    """
    mock_credentials = mock.Mock(spec=google.auth.credentials.Credentials)
    return google.cloud.bigquery.client.Client(project, mock_credentials, **kw)


def make_dataset(project, ds_id):
    """Return a ``Dataset`` for ``project``/``ds_id``."""
    reference = google.cloud.bigquery.dataset.DatasetReference(project, ds_id)
    return google.cloud.bigquery.dataset.Dataset(reference)


def make_dataset_list_item(project, ds_id):
    """Return a ``DatasetListItem`` built from a minimal API resource."""
    resource = {"datasetReference": {"projectId": project, "datasetId": ds_id}}
    return google.cloud.bigquery.dataset.DatasetListItem(resource)


def identity(x):
    """Return *x* unchanged (no-op transform for parametrized tests)."""
    return x


def get_reference(x):
    """Return the ``reference`` attribute of *x* (e.g. a DatasetListItem)."""
    return x.reference


# Each entry pairs a "dataset-like" factory (called with project, dataset_id)
# with a function mapping the produced object to the DatasetReference the
# client code is expected to resolve it to.
dataset_like = [
    (google.cloud.bigquery.dataset.DatasetReference, identity),
    (make_dataset, identity),
    (make_dataset_list_item, get_reference),
]

# Decorator: runs a test once per dataset-like type accepted by client APIs;
# the test receives fixtures named ``make_dataset`` and ``get_reference``.
dataset_polymorphic = pytest.mark.parametrize(
    "make_dataset,get_reference", dataset_like
)
157 changes: 0 additions & 157 deletions tests/unit/test_client.py
Expand Up @@ -2926,30 +2926,6 @@ def test_update_table_delete_property(self):
self.assertEqual(req[1]["data"], sent)
self.assertIsNone(table3.description)

def test_list_tables_empty_w_timeout(self):
    """An empty tables response yields no items and forwards the timeout."""
    expected_path = "/projects/%s/datasets/%s/tables" % (self.PROJECT, self.DS_ID)
    credentials = _make_credentials()
    client = self._make_one(project=self.PROJECT, credentials=credentials)
    connection = client._connection = make_connection({})

    dataset = DatasetReference(self.PROJECT, self.DS_ID)
    iterator = client.list_tables(dataset, timeout=7.5)
    self.assertIs(iterator.dataset, dataset)

    # Patch tracing so the span attributes can be asserted on.
    with mock.patch(
        "google.cloud.bigquery.opentelemetry_tracing._get_final_span_attributes"
    ) as final_attributes:
        page = next(iterator.pages)

    final_attributes.assert_called_once_with({"path": expected_path}, client, None)
    self.assertEqual(list(page), [])
    self.assertIsNone(iterator.next_page_token)
    connection.api_request.assert_called_once_with(
        method="GET", path=expected_path, query_params={}, timeout=7.5
    )

def test_list_models_empty_w_timeout(self):
path = "/projects/{}/datasets/{}/models".format(self.PROJECT, self.DS_ID)
creds = _make_credentials()
Expand Down Expand Up @@ -3125,139 +3101,6 @@ def test_list_routines_wrong_type(self):
DatasetReference(self.PROJECT, self.DS_ID).table("foo")
)

def test_list_tables_defaults(self):
    """list_tables with only a dataset returns TableListItems plus a next token."""
    from google.cloud.bigquery.table import TableListItem

    table_id_1 = "table_one"
    table_id_2 = "table_two"
    path = "projects/%s/datasets/%s/tables" % (self.PROJECT, self.DS_ID)
    next_token = "TOKEN"

    def table_resource(table_id):
        # Minimal ``tables.list`` entry for one table.
        return {
            "kind": "bigquery#table",
            "id": "%s:%s.%s" % (self.PROJECT, self.DS_ID, table_id),
            "tableReference": {
                "tableId": table_id,
                "datasetId": self.DS_ID,
                "projectId": self.PROJECT,
            },
            "type": "TABLE",
        }

    resource = {
        "nextPageToken": next_token,
        "tables": [table_resource(table_id_1), table_resource(table_id_2)],
    }

    credentials = _make_credentials()
    client = self._make_one(project=self.PROJECT, credentials=credentials)
    connection = client._connection = make_connection(resource)
    dataset = DatasetReference(self.PROJECT, self.DS_ID)

    iterator = client.list_tables(dataset)
    self.assertIs(iterator.dataset, dataset)
    with mock.patch(
        "google.cloud.bigquery.opentelemetry_tracing._get_final_span_attributes"
    ) as final_attributes:
        page = next(iterator.pages)

    final_attributes.assert_called_once_with({"path": "/%s" % path}, client, None)
    tables = list(page)

    self.assertEqual(len(tables), len(resource["tables"]))
    for found, expected in zip(tables, resource["tables"]):
        self.assertIsInstance(found, TableListItem)
        self.assertEqual(found.full_table_id, expected["id"])
        self.assertEqual(found.table_type, expected["type"])
    self.assertEqual(iterator.next_page_token, next_token)

    connection.api_request.assert_called_once_with(
        method="GET", path="/%s" % path, query_params={}, timeout=None
    )

def test_list_tables_explicit(self):
    """Explicit list_tables options are forwarded; a string dataset ID works.

    Fix: the canned ``tables.list`` entries previously said
    ``"kind": "bigquery#dataset"`` — table resources have kind
    ``bigquery#table`` (the assertions never checked it, so behavior is
    unchanged, but the fixture now matches the real API shape).
    """
    from google.cloud.bigquery.table import TableListItem

    table_id_1 = "table_one"
    table_id_2 = "table_two"
    path = "projects/%s/datasets/%s/tables" % (self.PROJECT, self.DS_ID)
    page_token = "TOKEN"

    def table_resource(table_id):
        # Minimal ``tables.list`` entry; kind must be bigquery#table.
        return {
            "kind": "bigquery#table",
            "id": "%s:%s.%s" % (self.PROJECT, self.DS_ID, table_id),
            "tableReference": {
                "tableId": table_id,
                "datasetId": self.DS_ID,
                "projectId": self.PROJECT,
            },
            "type": "TABLE",
        }

    resource = {"tables": [table_resource(table_id_1), table_resource(table_id_2)]}

    credentials = _make_credentials()
    client = self._make_one(project=self.PROJECT, credentials=credentials)
    connection = client._connection = make_connection(resource)
    dataset = DatasetReference(self.PROJECT, self.DS_ID)

    iterator = client.list_tables(
        # Test with string for dataset ID.
        self.DS_ID,
        max_results=3,
        page_token=page_token,
    )
    self.assertEqual(iterator.dataset, dataset)
    with mock.patch(
        "google.cloud.bigquery.opentelemetry_tracing._get_final_span_attributes"
    ) as final_attributes:
        page = next(iterator.pages)

    final_attributes.assert_called_once_with({"path": "/%s" % path}, client, None)
    tables = list(page)

    self.assertEqual(len(tables), len(resource["tables"]))
    for found, expected in zip(tables, resource["tables"]):
        self.assertIsInstance(found, TableListItem)
        self.assertEqual(found.full_table_id, expected["id"])
        self.assertEqual(found.table_type, expected["type"])
    self.assertIsNone(iterator.next_page_token)

    connection.api_request.assert_called_once_with(
        method="GET",
        path="/%s" % path,
        query_params={"maxResults": 3, "pageToken": page_token},
        timeout=None,
    )

def test_list_tables_wrong_type(self):
    """A TableReference is rejected where a dataset-like value is expected."""
    credentials = _make_credentials()
    client = self._make_one(project=self.PROJECT, credentials=credentials)
    table_ref = DatasetReference(self.PROJECT, self.DS_ID).table("foo")
    with self.assertRaises(TypeError):
        client.list_tables(table_ref)

def test_delete_dataset(self):
from google.cloud.bigquery.dataset import Dataset
from google.cloud.bigquery.dataset import DatasetReference
Expand Down