# [START aiplatform_sdk_create_and_import_dataset_tabular_bigquery_sample]
def create_and_import_dataset_tabular_bigquery_sample(
    project: str,
    location: str,
    display_name: str,
    src_uris: Union[str, List[str]],
    sync: bool,
):
    """Create a Vertex AI TabularDataset from a BigQuery source and print its info.

    Args:
        project: Google Cloud project ID.
        location: Vertex AI region (e.g. "us-central1").
        display_name: Human-readable name for the new dataset.
        src_uris: One or more BigQuery source URIs (e.g. "bq://project.dataset.table").
        sync: Whether ``TabularDataset.create`` should block until creation completes.

    Returns:
        The created ``aiplatform.TabularDataset``.
    """
    # Point the SDK at the target project and region before any resource calls.
    aiplatform.init(project=project, location=location)

    dataset = aiplatform.TabularDataset.create(
        display_name=display_name,
        bq_source=src_uris,
        sync=sync,
    )

    # Ensure the (possibly asynchronous) creation operation has finished
    # before reading resource attributes below.
    dataset.wait()

    print(dataset.display_name)
    print(dataset.resource_name)
    return dataset


# [END aiplatform_sdk_create_and_import_dataset_tabular_bigquery_sample]
def test_create_and_import_dataset_tabular_bigquery_sample(
    mock_sdk_init, mock_create_tabular_dataset
):
    """Check the sample initializes the SDK and creates the dataset with the expected args."""
    # Alias the long module-qualified sample entry point for readability.
    sample = (
        create_and_import_dataset_tabular_bigquery_sample
        .create_and_import_dataset_tabular_bigquery_sample
    )

    sample(
        project=constants.PROJECT,
        location=constants.LOCATION,
        src_uris=constants.BIGQUERY_SOURCE,
        display_name=constants.DISPLAY_NAME,
        sync=True,
    )

    # The SDK must be initialized exactly once, for the test project/region.
    mock_sdk_init.assert_called_once_with(
        project=constants.PROJECT, location=constants.LOCATION
    )
    # Dataset creation must receive the BigQuery source and display name verbatim.
    mock_create_tabular_dataset.assert_called_once_with(
        display_name=constants.DISPLAY_NAME,
        bq_source=constants.BIGQUERY_SOURCE,
        sync=True,
    )
f"{PARENT}/trainingJobs/{RESOURCE_ID}" GCS_SOURCES = ["gs://bucket1/source1.jsonl", "gs://bucket7/source4.jsonl"] +BIGQUERY_SOURCE = "bq://bigquery-public-data.ml_datasets.iris" GCS_DESTINATION = "gs://bucket3/output-dir/" TRAINING_FRACTION_SPLIT = 0.7