diff --git a/google/cloud/bigquery/client.py b/google/cloud/bigquery/client.py
index a9c77d5e1..da5b30a35 100644
--- a/google/cloud/bigquery/client.py
+++ b/google/cloud/bigquery/client.py
@@ -65,6 +65,7 @@
 from google.cloud.bigquery import job
 from google.cloud.bigquery.model import Model
 from google.cloud.bigquery.model import ModelReference
+from google.cloud.bigquery.model import _model_arg_to_model_ref
 from google.cloud.bigquery.query import _QueryResults
 from google.cloud.bigquery.retry import DEFAULT_RETRY
 from google.cloud.bigquery.routine import Routine
@@ -1364,9 +1365,17 @@ def create_job(self, job_config, retry=DEFAULT_RETRY):
                 job_config
             )
             source = _get_sub_prop(job_config, ["extract", "sourceTable"])
+            source_type = "Table"
+            if not source:
+                source = _get_sub_prop(job_config, ["extract", "sourceModel"])
+                source_type = "Model"
             destination_uris = _get_sub_prop(job_config, ["extract", "destinationUris"])
             return self.extract_table(
-                source, destination_uris, job_config=extract_job_config, retry=retry
+                source,
+                destination_uris,
+                job_config=extract_job_config,
+                retry=retry,
+                source_type=source_type,
             )
         elif "query" in job_config:
             copy_config = copy.deepcopy(job_config)
@@ -2282,6 +2291,7 @@ def extract_table(
         job_config=None,
         retry=DEFAULT_RETRY,
         timeout=None,
+        source_type="Table",
     ):
         """Start a job to extract a table into Cloud Storage files.

@@ -2292,9 +2302,11 @@
             source (Union[ \
                 google.cloud.bigquery.table.Table, \
                 google.cloud.bigquery.table.TableReference, \
+                google.cloud.bigquery.model.Model, \
+                google.cloud.bigquery.model.ModelReference, \
-                src, \
+                str, \
             ]):
-                Table to be extracted.
+                Table or Model to be extracted.
             destination_uris (Union[str, Sequence[str]]):
                 URIs of Cloud Storage file(s) into which table data is to be
                 extracted; in format
@@ -2319,9 +2331,9 @@
             timeout (Optional[float]):
                 The number of seconds to wait for the underlying HTTP transport
                 before using ``retry``.
-        Args:
-            source (google.cloud.bigquery.table.TableReference): table to be extracted.
-
+            source_type (str):
+                (Optional) Type of source to be extracted. ``Table`` or ``Model``.
+                Defaults to ``Table``.
         Returns:
             google.cloud.bigquery.job.ExtractJob: A new extract job instance.

@@ -2329,7 +2341,9 @@
             TypeError:
                 If ``job_config`` is not an instance of
                 :class:`~google.cloud.bigquery.job.ExtractJobConfig` class.
-        """
+            ValueError:
+                If ``source_type`` is not among ``Table``, ``Model``.
+        """
         job_id = _make_job_id(job_id, job_id_prefix)

         if project is None:
@@ -2339,7 +2353,17 @@
             location = self.location

         job_ref = job._JobReference(job_id, project=project, location=location)
-        source = _table_arg_to_table_ref(source, default_project=self.project)
+        src = source_type.lower()
+        if src == "table":
+            source = _table_arg_to_table_ref(source, default_project=self.project)
+        elif src == "model":
+            source = _model_arg_to_model_ref(source, default_project=self.project)
+        else:
+            raise ValueError(
+                "Cannot pass `{}` as a ``source_type``, pass Table or Model".format(
+                    source_type
+                )
+            )

         if isinstance(destination_uris, six.string_types):
             destination_uris = [destination_uris]
diff --git a/google/cloud/bigquery/job.py b/google/cloud/bigquery/job.py
index 4f3103bb5..25dd446e8 100644
--- a/google/cloud/bigquery/job.py
+++ b/google/cloud/bigquery/job.py
@@ -1990,8 +1990,11 @@ class ExtractJob(_AsyncJob):
     Args:
         job_id (str): the job's ID.

-        source (google.cloud.bigquery.table.TableReference):
-            Table into which data is to be loaded.
+        source (Union[ \
+            google.cloud.bigquery.table.TableReference, \
+            google.cloud.bigquery.model.ModelReference \
+        ]):
+            Table or Model from which data is to be extracted.

         destination_uris (List[str]):
             URIs describing where the extracted data will be written in Cloud
@@ -2067,14 +2070,20 @@ def destination_uri_file_counts(self):

     def to_api_repr(self):
         """Generate a resource for :meth:`_begin`."""
+        configuration = self._configuration.to_api_repr()
         source_ref = {
             "projectId": self.source.project,
             "datasetId": self.source.dataset_id,
-            "tableId": self.source.table_id,
         }
-        configuration = self._configuration.to_api_repr()
-        _helpers._set_sub_prop(configuration, ["extract", "sourceTable"], source_ref)
+        source = "sourceTable"
+        if isinstance(self.source, TableReference):
+            source_ref["tableId"] = self.source.table_id
+        else:
+            source_ref["modelId"] = self.source.model_id
+            source = "sourceModel"
+
+        _helpers._set_sub_prop(configuration, ["extract", source], source_ref)
         _helpers._set_sub_prop(
             configuration, ["extract", "destinationUris"], self.destination_uris
         )

@@ -2112,10 +2121,20 @@
         source_config = _helpers._get_sub_prop(
             config_resource, ["extract", "sourceTable"]
         )
-        dataset = DatasetReference(
-            source_config["projectId"], source_config["datasetId"]
-        )
-        source = dataset.table(source_config["tableId"])
+        if source_config:
+            dataset = DatasetReference(
+                source_config["projectId"], source_config["datasetId"]
+            )
+            source = dataset.table(source_config["tableId"])
+        else:
+            source_config = _helpers._get_sub_prop(
+                config_resource, ["extract", "sourceModel"]
+            )
+            dataset = DatasetReference(
+                source_config["projectId"], source_config["datasetId"]
+            )
+            source = dataset.model(source_config["modelId"])
+
         destination_uris = _helpers._get_sub_prop(
             config_resource, ["extract", "destinationUris"]
         )
diff --git a/google/cloud/bigquery/model.py b/google/cloud/bigquery/model.py
index a2510e86c..eb459f57a 100644
--- a/google/cloud/bigquery/model.py
+++ b/google/cloud/bigquery/model.py
@@ -433,3 +433,15 @@ def __repr__(self):
         return "ModelReference(project_id='{}', dataset_id='{}', model_id='{}')".format(
             self.project, self.dataset_id, self.model_id
         )
+
+
+def _model_arg_to_model_ref(value, default_project=None):
+    """Helper to convert a string or Model to ModelReference.
+
+    This function keeps ModelReference and other kinds of objects unchanged.
+ """ + if isinstance(value, six.string_types): + return ModelReference.from_string(value, default_project=default_project) + if isinstance(value, Model): + return value.reference + return value diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index fddfa4b1b..6edb2e168 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -2884,6 +2884,21 @@ def test_create_job_extract_config(self): configuration, "google.cloud.bigquery.client.Client.extract_table", ) + def test_create_job_extract_config_for_model(self): + configuration = { + "extract": { + "sourceModel": { + "projectId": self.PROJECT, + "datasetId": self.DS_ID, + "modelId": "source_model", + }, + "destinationUris": ["gs://test_bucket/dst_object*"], + } + } + self._create_job_helper( + configuration, "google.cloud.bigquery.client.Client.extract_table", + ) + def test_create_job_query_config(self): configuration = { "query": {"query": "query", "destinationTable": {"tableId": "table_id"}} @@ -4217,6 +4232,140 @@ def test_extract_table_w_destination_uris(self): self.assertEqual(job.source, source) self.assertEqual(list(job.destination_uris), [DESTINATION1, DESTINATION2]) + def test_extract_table_for_source_type_model(self): + from google.cloud.bigquery.job import ExtractJob + + JOB = "job_id" + SOURCE = "source_model" + DESTINATION = "gs://bucket_name/object_name" + RESOURCE = { + "jobReference": {"projectId": self.PROJECT, "jobId": JOB}, + "configuration": { + "extract": { + "sourceModel": { + "projectId": self.PROJECT, + "datasetId": self.DS_ID, + "modelId": SOURCE, + }, + "destinationUris": [DESTINATION], + } + }, + } + creds = _make_credentials() + http = object() + client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) + conn = client._connection = make_connection(RESOURCE) + dataset = DatasetReference(self.PROJECT, self.DS_ID) + source = dataset.model(SOURCE) + + job = client.extract_table( + source, DESTINATION, job_id=JOB, timeout=7.5, source_type="Model" + ) + + # Check that extract_table actually starts the job. + conn.api_request.assert_called_once_with( + method="POST", path="/projects/PROJECT/jobs", data=RESOURCE, timeout=7.5, + ) + + # Check the job resource. + self.assertIsInstance(job, ExtractJob) + self.assertIs(job._client, client) + self.assertEqual(job.job_id, JOB) + self.assertEqual(job.source, source) + self.assertEqual(list(job.destination_uris), [DESTINATION]) + + def test_extract_table_for_source_type_model_w_string_model_id(self): + JOB = "job_id" + source_id = "source_model" + DESTINATION = "gs://bucket_name/object_name" + RESOURCE = { + "jobReference": {"projectId": self.PROJECT, "jobId": JOB}, + "configuration": { + "extract": { + "sourceModel": { + "projectId": self.PROJECT, + "datasetId": self.DS_ID, + "modelId": source_id, + }, + "destinationUris": [DESTINATION], + } + }, + } + creds = _make_credentials() + http = object() + client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) + conn = client._connection = make_connection(RESOURCE) + + client.extract_table( + # Test with string for model ID. + "{}.{}".format(self.DS_ID, source_id), + DESTINATION, + job_id=JOB, + timeout=7.5, + source_type="Model", + ) + + # Check that extract_table actually starts the job. 
+        conn.api_request.assert_called_once_with(
+            method="POST", path="/projects/PROJECT/jobs", data=RESOURCE, timeout=7.5,
+        )
+
+    def test_extract_table_for_source_type_model_w_model_object(self):
+        from google.cloud.bigquery.model import Model
+
+        JOB = "job_id"
+        DESTINATION = "gs://bucket_name/object_name"
+        model_id = "{}.{}.{}".format(self.PROJECT, self.DS_ID, self.MODEL_ID)
+        model = Model(model_id)
+        RESOURCE = {
+            "jobReference": {"projectId": self.PROJECT, "jobId": JOB},
+            "configuration": {
+                "extract": {
+                    "sourceModel": {
+                        "projectId": self.PROJECT,
+                        "datasetId": self.DS_ID,
+                        "modelId": self.MODEL_ID,
+                    },
+                    "destinationUris": [DESTINATION],
+                }
+            },
+        }
+        creds = _make_credentials()
+        http = object()
+        client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)
+        conn = client._connection = make_connection(RESOURCE)
+
+        client.extract_table(
+            # Test with Model class object.
+            model,
+            DESTINATION,
+            job_id=JOB,
+            timeout=7.5,
+            source_type="Model",
+        )
+
+        # Check that extract_table actually starts the job.
+        conn.api_request.assert_called_once_with(
+            method="POST", path="/projects/PROJECT/jobs", data=RESOURCE, timeout=7.5,
+        )
+
+    def test_extract_table_for_invalid_source_type_model(self):
+        JOB = "job_id"
+        SOURCE = "source_model"
+        DESTINATION = "gs://bucket_name/object_name"
+        creds = _make_credentials()
+        http = object()
+        client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)
+        dataset = DatasetReference(self.PROJECT, self.DS_ID)
+        source = dataset.model(SOURCE)
+
+        with self.assertRaises(ValueError) as exc:
+            client.extract_table(
+                source, DESTINATION, job_id=JOB, timeout=7.5, source_type="foo"
+            )
+
+        self.assertIn("Cannot pass", exc.exception.args[0])
+
     def test_query_defaults(self):
         from google.cloud.bigquery.job import QueryJob
diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py
index 3e642142d..d97efd946 100644
--- a/tests/unit/test_job.py
+++ b/tests/unit/test_job.py
@@ -3176,10 +3176,16 @@ def _verifyResourceProperties(self, job, resource):

         self.assertEqual(job.destination_uris, config["destinationUris"])

-        table_ref = config["sourceTable"]
-        self.assertEqual(job.source.project, table_ref["projectId"])
-        self.assertEqual(job.source.dataset_id, table_ref["datasetId"])
-        self.assertEqual(job.source.table_id, table_ref["tableId"])
+        if "sourceTable" in config:
+            table_ref = config["sourceTable"]
+            self.assertEqual(job.source.project, table_ref["projectId"])
+            self.assertEqual(job.source.dataset_id, table_ref["datasetId"])
+            self.assertEqual(job.source.table_id, table_ref["tableId"])
+        else:
+            model_ref = config["sourceModel"]
+            self.assertEqual(job.source.project, model_ref["projectId"])
+            self.assertEqual(job.source.dataset_id, model_ref["datasetId"])
+            self.assertEqual(job.source.model_id, model_ref["modelId"])

         if "compression" in config:
             self.assertEqual(job.compression, config["compression"])
@@ -3281,6 +3287,28 @@ def test_from_api_repr_bare(self):
         self.assertIs(job._client, client)
         self._verifyResourceProperties(job, RESOURCE)

+    def test_from_api_repr_for_model(self):
+        self._setUpConstants()
+        client = _make_client(project=self.PROJECT)
+        RESOURCE = {
+            "id": self.JOB_ID,
+            "jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID},
+            "configuration": {
+                "extract": {
+                    "sourceModel": {
+                        "projectId": self.PROJECT,
+                        "datasetId": self.DS_ID,
+                        "modelId": "model_id",
+                    },
+                    "destinationUris": [self.DESTINATION_URI],
+                }
+            },
+        }
+        klass = self._get_target_class()
+        job = klass.from_api_repr(RESOURCE, client=client)
+        self.assertIs(job._client, client)
+        self._verifyResourceProperties(job, RESOURCE)
+
     def test_from_api_repr_w_properties(self):
         from google.cloud.bigquery.job import Compression
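
For reference, a minimal usage sketch of the behavior this patch enables (not part of the diff; the project, dataset, model, and bucket names below are placeholders):

    from google.cloud import bigquery

    client = bigquery.Client()

    # With source_type="Model", ``source`` may be a str, Model, or
    # ModelReference; the new _model_arg_to_model_ref helper coerces it.
    extract_job = client.extract_table(
        "my_project.my_dataset.my_model",  # placeholder model ID
        "gs://my_bucket/model_export/*",   # placeholder destination URI
        source_type="Model",               # new argument; defaults to "Table"
    )
    extract_job.result()  # block until the extract job finishes

Per the validation added in extract_table, any source_type other than "Table" or "Model" (matched case-insensitively) raises ValueError.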