From 481d172542ffd80e18f4fab5b01945be17d5e18c Mon Sep 17 00:00:00 2001
From: sasha-gitg <44654632+sasha-gitg@users.noreply.github.com>
Date: Thu, 10 Jun 2021 20:00:21 -0400
Subject: [PATCH] fix: pass credentials to BQ and GCS clients (#469)

---
 google/cloud/aiplatform/datasets/dataset.py |  2 +-
 .../aiplatform/datasets/tabular_dataset.py  | 34 ++++++++---
 tests/unit/aiplatform/test_datasets.py      | 58 ++++++++++++++++---
 3 files changed, 76 insertions(+), 18 deletions(-)

diff --git a/google/cloud/aiplatform/datasets/dataset.py b/google/cloud/aiplatform/datasets/dataset.py
index 1eb1663b2b..df402d0c99 100644
--- a/google/cloud/aiplatform/datasets/dataset.py
+++ b/google/cloud/aiplatform/datasets/dataset.py
@@ -68,7 +68,7 @@ def __init__(
                 Optional location to retrieve dataset from. If not set, location
                 set in aiplatform.init will be used.
             credentials (auth_credentials.Credentials):
-                Custom credentials to use to upload this model. Overrides
+                Custom credentials to use to retrieve this Dataset. Overrides
                 credentials set in aiplatform.init.
         """

diff --git a/google/cloud/aiplatform/datasets/tabular_dataset.py b/google/cloud/aiplatform/datasets/tabular_dataset.py
index 95f1b16f98..71c9d4f7d7 100644
--- a/google/cloud/aiplatform/datasets/tabular_dataset.py
+++ b/google/cloud/aiplatform/datasets/tabular_dataset.py
@@ -73,20 +73,28 @@ def column_names(self) -> List[str]:
             gcs_source_uris.sort()
             # Get the first file in sorted list
-            return TabularDataset._retrieve_gcs_source_columns(
-                self.project, gcs_source_uris[0]
+            return self._retrieve_gcs_source_columns(
+                project=self.project,
+                gcs_csv_file_path=gcs_source_uris[0],
+                credentials=self.credentials,
             )
         elif bq_source:
             bq_table_uri = bq_source.get("uri")
             if bq_table_uri:
-                return TabularDataset._retrieve_bq_source_columns(
-                    self.project, bq_table_uri
+                return self._retrieve_bq_source_columns(
+                    project=self.project,
+                    bq_table_uri=bq_table_uri,
+                    credentials=self.credentials,
                 )

         raise RuntimeError("No valid CSV or BigQuery datasource found.")

     @staticmethod
-    def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[str]:
+    def _retrieve_gcs_source_columns(
+        project: str,
+        gcs_csv_file_path: str,
+        credentials: Optional[auth_credentials.Credentials] = None,
+    ) -> List[str]:
         """Retrieve the columns from a comma-delimited CSV file stored on Google Cloud
         Storage

         Example Usage:
@@ -104,7 +112,8 @@ def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[s
             gcs_csv_file_path (str):
                 Required. A full path to a CSV file stored on Google Cloud
                 Storage. Must include "gs://" prefix.
-
+            credentials (auth_credentials.Credentials):
+                Credentials to use with the GCS Client.
         Returns:
             List[str]
                 A list of column names in the CSV file.
@@ -116,7 +125,7 @@ def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[s
         gcs_bucket, gcs_blob = utils.extract_bucket_and_prefix_from_gcs_path(
             gcs_csv_file_path
         )
-        client = storage.Client(project=project)
+        client = storage.Client(project=project, credentials=credentials)
         bucket = client.bucket(gcs_bucket)
         blob = bucket.blob(gcs_blob)

@@ -135,6 +144,7 @@ def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[s
             line += blob.download_as_bytes(
                 start=start_index, end=start_index + increment
             ).decode("utf-8")
+
             first_new_line_index = line.find("\n")
             start_index += increment

@@ -156,7 +166,11 @@ def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[s
         return next(csv_reader)

     @staticmethod
-    def _retrieve_bq_source_columns(project: str, bq_table_uri: str) -> List[str]:
+    def _retrieve_bq_source_columns(
+        project: str,
+        bq_table_uri: str,
+        credentials: Optional[auth_credentials.Credentials] = None,
+    ) -> List[str]:
         """Retrieve the columns from a table on Google BigQuery

         Example Usage:
@@ -174,6 +188,8 @@ def _retrieve_bq_source_columns(project: str, bq_table_uri: str) -> List[str]:
             bq_table_uri (str):
                 Required. A URI to a BigQuery table. Can include "bq://" prefix
                 but not required.
+            credentials (auth_credentials.Credentials):
+                Credentials to use with the BQ Client.

         Returns:
             List[str]
@@ -185,7 +201,7 @@ def _retrieve_bq_source_columns(project: str, bq_table_uri: str) -> List[str]:
         if bq_table_uri.startswith(prefix):
             bq_table_uri = bq_table_uri[len(prefix) :]

-        client = bigquery.Client(project=project)
+        client = bigquery.Client(project=project, credentials=credentials)
         table = client.get_table(bq_table_uri)
         schema = table.schema
         return [schema.name for schema in schema]
diff --git a/tests/unit/aiplatform/test_datasets.py b/tests/unit/aiplatform/test_datasets.py
index 4c2a75c393..5da47bea59 100644
--- a/tests/unit/aiplatform/test_datasets.py
+++ b/tests/unit/aiplatform/test_datasets.py
@@ -341,16 +341,30 @@ def list_datasets_mock():
 @pytest.fixture
 def gcs_client_download_as_bytes_mock():
-    with patch.object(storage.Blob, "download_as_bytes") as bigquery_blob_mock:
-        bigquery_blob_mock.return_value = b'"column_1","column_2"\n0, 1'
-        yield bigquery_blob_mock
+    with patch.object(storage.Blob, "download_as_bytes") as gcs_blob_mock:
+        gcs_blob_mock.return_value = b'"column_1","column_2"\n0, 1'
+        yield gcs_blob_mock


 @pytest.fixture
-def bigquery_client_mock():
-    with patch.object(bigquery.Client, "get_table") as bigquery_client_mock:
-        bigquery_client_mock.return_value = bigquery.Table("project.dataset.table")
-        yield bigquery_client_mock
+def gcs_client_mock():
+    with patch.object(storage, "Client") as client_mock:
+        yield client_mock
+
+
+@pytest.fixture
+def bq_client_mock():
+    with patch.object(bigquery, "Client") as client_mock:
+        yield client_mock
+
+
+@pytest.fixture
+def bigquery_client_table_mock():
+    with patch.object(bigquery.Client, "get_table") as bigquery_client_table_mock:
+        bigquery_client_table_mock.return_value = bigquery.Table(
+            "project.dataset.table"
+        )
+        yield bigquery_client_table_mock


 @pytest.fixture
@@ -995,9 +1009,37 @@ def test_tabular_dataset_column_name_gcs(self):

         assert my_dataset.column_names == ["column_1", "column_2"]

+    @pytest.mark.usefixtures("get_dataset_tabular_gcs_mock")
+    def test_tabular_dataset_column_name_gcs_with_creds(self, gcs_client_mock):
+        creds = auth_credentials.AnonymousCredentials()
+        my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME, credentials=creds)
+
+        # We are just testing creds passing here;
+        # this exception is from the mock not returning
+        # the CSV data, which is tested above.
+        try:
+            my_dataset.column_names
+        except StopIteration:
+            pass
+
+        gcs_client_mock.assert_called_once_with(
+            project=_TEST_PROJECT, credentials=creds
+        )
+
+    @pytest.mark.usefixtures("get_dataset_tabular_bq_mock")
+    def test_tabular_dataset_column_name_bq_with_creds(self, bq_client_mock):
+        creds = auth_credentials.AnonymousCredentials()
+        my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME, credentials=creds)
+
+        my_dataset.column_names
+
+        assert bq_client_mock.call_args_list[0] == mock.call(
+            project=_TEST_PROJECT, credentials=creds
+        )
+
     @pytest.mark.usefixtures(
         "get_dataset_tabular_bq_mock",
-        "bigquery_client_mock",
+        "bigquery_client_table_mock",
         "bigquery_table_schema_mock",
     )
     def test_tabular_dataset_column_name_bigquery(self):
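
Note (not part of the patch): a minimal usage sketch of the behavior this
change enables. The dataset resource name below is a placeholder, and
AnonymousCredentials merely stands in for real credentials such as a
service account.

    from google.auth import credentials as auth_credentials
    from google.cloud import aiplatform

    # Hypothetical resource name; substitute a real TabularDataset.
    DATASET_NAME = "projects/my-project/locations/us-central1/datasets/1234567890"

    # Stand-in credentials; use service-account credentials in practice.
    creds = auth_credentials.AnonymousCredentials()

    # Credentials supplied here are now forwarded to the storage.Client /
    # bigquery.Client that column_names creates internally, rather than the
    # clients silently falling back to application default credentials.
    dataset = aiplatform.TabularDataset(dataset_name=DATASET_NAME, credentials=creds)
    print(dataset.column_names)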