
feat(bigquery): add support of model for extract job #71

Merged
merged 8 commits on May 11, 2020
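For context, this change lets Client.extract_table export a BigQuery ML model as well as a table, selected by a new source_type parameter. A minimal usage sketch, assuming placeholder project, dataset, model, and bucket names:

    from google.cloud import bigquery

    client = bigquery.Client()

    # Existing behavior: source_type defaults to "Table".
    table_job = client.extract_table(
        "my-project.my_dataset.my_table",
        "gs://my-bucket/table-export/*.csv",
    )

    # New in this PR: pass source_type="Model" to extract a model.
    model_job = client.extract_table(
        "my-project.my_dataset.my_model",
        "gs://my-bucket/model-export/",
        source_type="Model",
    )
    model_job.result()  # Block until the extract job completes.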
38 changes: 31 additions & 7 deletions google/cloud/bigquery/client.py
@@ -65,6 +65,7 @@
from google.cloud.bigquery import job
from google.cloud.bigquery.model import Model
from google.cloud.bigquery.model import ModelReference
+from google.cloud.bigquery.model import _model_arg_to_model_ref
from google.cloud.bigquery.query import _QueryResults
from google.cloud.bigquery.retry import DEFAULT_RETRY
from google.cloud.bigquery.routine import Routine
@@ -1364,9 +1365,17 @@ def create_job(self, job_config, retry=DEFAULT_RETRY):
job_config
)
source = _get_sub_prop(job_config, ["extract", "sourceTable"])
source_type = "Table"
if not source:
source = _get_sub_prop(job_config, ["extract", "sourceModel"])
source_type = "Model"
destination_uris = _get_sub_prop(job_config, ["extract", "destinationUris"])
return self.extract_table(
source, destination_uris, job_config=extract_job_config, retry=retry
source,
destination_uris,
job_config=extract_job_config,
retry=retry,
source_type=source_type,
)
elif "query" in job_config:
copy_config = copy.deepcopy(job_config)
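To make the new routing concrete: create_job now looks for "sourceTable" first and falls back to "sourceModel", forwarding the detected source_type to extract_table. A sketch of a raw extract job resource that takes the model branch, reusing the client from the sketch above (identifiers are placeholders):

    job_config = {
        "extract": {
            "sourceModel": {
                "projectId": "my-project",
                "datasetId": "my_dataset",
                "modelId": "my_model",
            },
            "destinationUris": ["gs://my-bucket/model-export/*"],
        }
    }
    # No "sourceTable" key, so create_job reads "sourceModel" and
    # dispatches to extract_table(..., source_type="Model").
    job = client.create_job(job_config)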
@@ -2282,6 +2291,7 @@ def extract_table(
job_config=None,
retry=DEFAULT_RETRY,
timeout=None,
source_type="Table",
):
"""Start a job to extract a table into Cloud Storage files.

@@ -2292,9 +2302,11 @@
source (Union[ \
google.cloud.bigquery.table.Table, \
google.cloud.bigquery.table.TableReference, \
+               google.cloud.bigquery.model.Model, \
+               google.cloud.bigquery.model.ModelReference, \
                str, \
            ]):
-               Table to be extracted.
+               Table or Model to be extracted.
destination_uris (Union[str, Sequence[str]]):
URIs of Cloud Storage file(s) into which table data is to be
extracted; in format
@@ -2319,17 +2331,19 @@
timeout (Optional[float]):
The number of seconds to wait for the underlying HTTP transport
before using ``retry``.
-           Args:
-               source (google.cloud.bigquery.table.TableReference): table to be extracted.

+           source_type (str):
+               (Optional) Type of source to be extracted. ``Table`` or ``Model``.
+               Defaults to ``Table``.
Returns:
google.cloud.bigquery.job.ExtractJob: A new extract job instance.

Raises:
TypeError:
If ``job_config`` is not an instance of :class:`~google.cloud.bigquery.job.ExtractJobConfig`
class.
"""
ValueError:
If ``source_type`` is not among ``Table``,``Model``.
"""
job_id = _make_job_id(job_id, job_id_prefix)

if project is None:
@@ -2339,7 +2353,17 @@
location = self.location

job_ref = job._JobReference(job_id, project=project, location=location)
-       source = _table_arg_to_table_ref(source, default_project=self.project)
+       src = source_type.lower()
+       if src == "table":
+           source = _table_arg_to_table_ref(source, default_project=self.project)
+       elif src == "model":
+           source = _model_arg_to_model_ref(source, default_project=self.project)
+       else:
+           raise ValueError(
+               "Cannot pass `{}` as a ``source_type``, pass Table or Model".format(
+                   source_type
+               )
+           )

if isinstance(destination_uris, six.string_types):
destination_uris = [destination_uris]
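Any other source_type value now fails fast instead of building a malformed job. A sketch of the failure mode, with placeholder arguments:

    client.extract_table(
        "my-project.my_dataset.my_model",
        "gs://my-bucket/export/*",
        source_type="Dataset",  # neither "Table" nor "Model"
    )
    # Raises ValueError: Cannot pass `Dataset` as a ``source_type``, pass Table or Model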
37 changes: 28 additions & 9 deletions google/cloud/bigquery/job.py
@@ -1990,8 +1990,11 @@ class ExtractJob(_AsyncJob):
Args:
job_id (str): the job's ID.

-       source (google.cloud.bigquery.table.TableReference):
-           Table into which data is to be loaded.
+       source (Union[ \
+           google.cloud.bigquery.table.TableReference, \
+           google.cloud.bigquery.model.ModelReference \
+       ]):
+           Table or Model from which data is to be loaded.
Contributor (review comment): Maybe s/loaded/extracted/ in the comment.

Contributor (author): @shollyman We don't have any model lifecycle test in the integration tests.

destination_uris (List[str]):
URIs describing where the extracted data will be written in Cloud
@@ -2067,14 +2070,20 @@ def destination_uri_file_counts(self):
def to_api_repr(self):
"""Generate a resource for :meth:`_begin`."""

+       configuration = self._configuration.to_api_repr()
        source_ref = {
            "projectId": self.source.project,
            "datasetId": self.source.dataset_id,
-           "tableId": self.source.table_id,
        }

-       configuration = self._configuration.to_api_repr()
-       _helpers._set_sub_prop(configuration, ["extract", "sourceTable"], source_ref)
+       source = "sourceTable"
+       if isinstance(self.source, TableReference):
+           source_ref["tableId"] = self.source.table_id
+       else:
+           source_ref["modelId"] = self.source.model_id
+           source = "sourceModel"

+       _helpers._set_sub_prop(configuration, ["extract", source], source_ref)
_helpers._set_sub_prop(
configuration, ["extract", "destinationUris"], self.destination_uris
)
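The net effect is that the serialized extract configuration is keyed by the source kind. A sketch of the two resource shapes to_api_repr now emits, with placeholder identifiers:

    # TableReference source:
    {
        "extract": {
            "sourceTable": {"projectId": "p", "datasetId": "d", "tableId": "t"},
            "destinationUris": ["gs://bucket/file*"],
        }
    }

    # ModelReference source:
    {
        "extract": {
            "sourceModel": {"projectId": "p", "datasetId": "d", "modelId": "m"},
            "destinationUris": ["gs://bucket/file*"],
        }
    }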
@@ -2112,10 +2121,20 @@ def from_api_repr(cls, resource, client):
source_config = _helpers._get_sub_prop(
config_resource, ["extract", "sourceTable"]
)
-       dataset = DatasetReference(
-           source_config["projectId"], source_config["datasetId"]
-       )
-       source = dataset.table(source_config["tableId"])
+       if source_config:
+           dataset = DatasetReference(
+               source_config["projectId"], source_config["datasetId"]
+           )
+           source = dataset.table(source_config["tableId"])
+       else:
+           source_config = _helpers._get_sub_prop(
+               config_resource, ["extract", "sourceModel"]
+           )
+           dataset = DatasetReference(
+               source_config["projectId"], source_config["datasetId"]
+           )
+           source = dataset.model(source_config["modelId"])

destination_uris = _helpers._get_sub_prop(
config_resource, ["extract", "destinationUris"]
)
12 changes: 12 additions & 0 deletions google/cloud/bigquery/model.py
@@ -433,3 +433,15 @@ def __repr__(self):
return "ModelReference(project_id='{}', dataset_id='{}', model_id='{}')".format(
self.project, self.dataset_id, self.model_id
)


def _model_arg_to_model_ref(value, default_project=None):
"""Helper to convert a string or Model to ModelReference.

This function keeps ModelReference and other kinds of objects unchanged.
"""
if isinstance(value, six.string_types):
return ModelReference.from_string(value, default_project=default_project)
if isinstance(value, Model):
return value.reference
return value
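A quick sketch of the three kinds of input the (private) helper accepts, with placeholder identifiers:

    from google.cloud.bigquery.model import Model, ModelReference, _model_arg_to_model_ref

    # A string model ID is parsed into a ModelReference.
    ref = _model_arg_to_model_ref("my_dataset.my_model", default_project="my-project")

    # A Model instance is reduced to its reference.
    ref = _model_arg_to_model_ref(Model("my-project.my_dataset.my_model"))

    # Anything else, e.g. an existing ModelReference, passes through unchanged.
    ref = _model_arg_to_model_ref(ModelReference.from_string("my-project.my_dataset.my_model"))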
149 changes: 149 additions & 0 deletions tests/unit/test_client.py
@@ -2884,6 +2884,21 @@ def test_create_job_extract_config(self):
configuration, "google.cloud.bigquery.client.Client.extract_table",
)

def test_create_job_extract_config_for_model(self):
configuration = {
"extract": {
"sourceModel": {
"projectId": self.PROJECT,
"datasetId": self.DS_ID,
"modelId": "source_model",
},
"destinationUris": ["gs://test_bucket/dst_object*"],
}
}
self._create_job_helper(
configuration, "google.cloud.bigquery.client.Client.extract_table",
)

def test_create_job_query_config(self):
configuration = {
"query": {"query": "query", "destinationTable": {"tableId": "table_id"}}
@@ -4217,6 +4232,140 @@ def test_extract_table_w_destination_uris(self):
self.assertEqual(job.source, source)
self.assertEqual(list(job.destination_uris), [DESTINATION1, DESTINATION2])

def test_extract_table_for_source_type_model(self):
from google.cloud.bigquery.job import ExtractJob

JOB = "job_id"
SOURCE = "source_model"
DESTINATION = "gs://bucket_name/object_name"
RESOURCE = {
"jobReference": {"projectId": self.PROJECT, "jobId": JOB},
"configuration": {
"extract": {
"sourceModel": {
"projectId": self.PROJECT,
"datasetId": self.DS_ID,
"modelId": SOURCE,
},
"destinationUris": [DESTINATION],
}
},
}
creds = _make_credentials()
http = object()
client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)
conn = client._connection = make_connection(RESOURCE)
dataset = DatasetReference(self.PROJECT, self.DS_ID)
source = dataset.model(SOURCE)

job = client.extract_table(
source, DESTINATION, job_id=JOB, timeout=7.5, source_type="Model"
)

# Check that extract_table actually starts the job.
conn.api_request.assert_called_once_with(
method="POST", path="/projects/PROJECT/jobs", data=RESOURCE, timeout=7.5,
)

# Check the job resource.
self.assertIsInstance(job, ExtractJob)
self.assertIs(job._client, client)
self.assertEqual(job.job_id, JOB)
self.assertEqual(job.source, source)
self.assertEqual(list(job.destination_uris), [DESTINATION])

def test_extract_table_for_source_type_model_w_string_model_id(self):
JOB = "job_id"
source_id = "source_model"
DESTINATION = "gs://bucket_name/object_name"
RESOURCE = {
"jobReference": {"projectId": self.PROJECT, "jobId": JOB},
"configuration": {
"extract": {
"sourceModel": {
"projectId": self.PROJECT,
"datasetId": self.DS_ID,
"modelId": source_id,
},
"destinationUris": [DESTINATION],
}
},
}
creds = _make_credentials()
http = object()
client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)
conn = client._connection = make_connection(RESOURCE)

client.extract_table(
# Test with string for model ID.
"{}.{}".format(self.DS_ID, source_id),
DESTINATION,
job_id=JOB,
timeout=7.5,
source_type="Model",
)

# Check that extract_table actually starts the job.
conn.api_request.assert_called_once_with(
method="POST", path="/projects/PROJECT/jobs", data=RESOURCE, timeout=7.5,
)

def test_extract_table_for_source_type_model_w_model_object(self):
from google.cloud.bigquery.model import Model

JOB = "job_id"
DESTINATION = "gs://bucket_name/object_name"
model_id = "{}.{}.{}".format(self.PROJECT, self.DS_ID, self.MODEL_ID)
model = Model(model_id)
RESOURCE = {
"jobReference": {"projectId": self.PROJECT, "jobId": JOB},
"configuration": {
"extract": {
"sourceModel": {
"projectId": self.PROJECT,
"datasetId": self.DS_ID,
"modelId": self.MODEL_ID,
},
"destinationUris": [DESTINATION],
}
},
}
creds = _make_credentials()
http = object()
client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)
conn = client._connection = make_connection(RESOURCE)

client.extract_table(
# Test with Model class object.
model,
DESTINATION,
job_id=JOB,
timeout=7.5,
source_type="Model",
)

# Check that extract_table actually starts the job.
conn.api_request.assert_called_once_with(
method="POST", path="/projects/PROJECT/jobs", data=RESOURCE, timeout=7.5,
)

def test_extract_table_for_invalid_source_type_model(self):
JOB = "job_id"
SOURCE = "source_model"
DESTINATION = "gs://bucket_name/object_name"
creds = _make_credentials()
http = object()
client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)
dataset = DatasetReference(self.PROJECT, self.DS_ID)
source = dataset.model(SOURCE)

with self.assertRaises(ValueError) as exc:
client.extract_table(
source, DESTINATION, job_id=JOB, timeout=7.5, source_type="foo"
)

self.assertIn("Cannot pass", exc.exception.args[0])

def test_query_defaults(self):
from google.cloud.bigquery.job import QueryJob

36 changes: 32 additions & 4 deletions tests/unit/test_job.py
@@ -3176,10 +3176,16 @@ def _verifyResourceProperties(self, job, resource):

self.assertEqual(job.destination_uris, config["destinationUris"])

table_ref = config["sourceTable"]
self.assertEqual(job.source.project, table_ref["projectId"])
self.assertEqual(job.source.dataset_id, table_ref["datasetId"])
self.assertEqual(job.source.table_id, table_ref["tableId"])
if "sourceTable" in config:
table_ref = config["sourceTable"]
self.assertEqual(job.source.project, table_ref["projectId"])
self.assertEqual(job.source.dataset_id, table_ref["datasetId"])
self.assertEqual(job.source.table_id, table_ref["tableId"])
else:
model_ref = config["sourceModel"]
self.assertEqual(job.source.project, model_ref["projectId"])
self.assertEqual(job.source.dataset_id, model_ref["datasetId"])
self.assertEqual(job.source.model_id, model_ref["modelId"])

if "compression" in config:
self.assertEqual(job.compression, config["compression"])
@@ -3281,6 +3287,28 @@ def test_from_api_repr_bare(self):
self.assertIs(job._client, client)
self._verifyResourceProperties(job, RESOURCE)

def test_from_api_repr_for_model(self):
self._setUpConstants()
client = _make_client(project=self.PROJECT)
RESOURCE = {
"id": self.JOB_ID,
"jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID},
"configuration": {
"extract": {
"sourceModel": {
"projectId": self.PROJECT,
"datasetId": self.DS_ID,
"modelId": "model_id",
},
"destinationUris": [self.DESTINATION_URI],
}
},
}
klass = self._get_target_class()
job = klass.from_api_repr(RESOURCE, client=client)
self.assertIs(job._client, client)
self._verifyResourceProperties(job, RESOURCE)

def test_from_api_repr_w_properties(self):
from google.cloud.bigquery.job import Compression
