fix: pass credentials to BQ and GCS clients #469

Merged
33 changes: 24 additions & 9 deletions google/cloud/aiplatform/datasets/tabular_dataset.py
@@ -73,20 +73,28 @@ def column_names(self) -> List[str]:
gcs_source_uris.sort()

# Get the first file in sorted list
-return TabularDataset._retrieve_gcs_source_columns(
-    self.project, gcs_source_uris[0]
+return self._retrieve_gcs_source_columns(
+    project=self.project,
+    gcs_csv_file_path=gcs_source_uris[0],
+    credentials=self.credentials
)

Contributor (commenting on the call above; see the usage sketch after this method):

Is there any use case where the user would want to pass in different credentials when calling dataset.column_names?

If not, this looks fine.
elif bq_source:
bq_table_uri = bq_source.get("uri")
if bq_table_uri:
-return TabularDataset._retrieve_bq_source_columns(
-    self.project, bq_table_uri
+return self._retrieve_bq_source_columns(
+    project=self.project,
+    bq_table_uri=bq_table_uri,
+    credentials=self.credentials
)

raise RuntimeError("No valid CSV or BigQuery datasource found.")
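To the Contributor question above: credentials are bound when the dataset object is constructed, so a caller who needs different credentials passes them at construction time and column_names picks them up from self.credentials. A sketch of that usage — the key-file path and dataset resource name are illustrative, not taken from this PR:

from google.cloud import aiplatform
from google.oauth2 import service_account

# Illustrative key file; any google.auth credentials object works here.
creds = service_account.Credentials.from_service_account_file("key.json")

dataset = aiplatform.TabularDataset(
    dataset_name="projects/my-project/locations/us-central1/datasets/123",
    credentials=creds,
)

# After this change, the GCS/BigQuery client used to read the schema is
# built with dataset.credentials instead of the environment defaults.
columns = dataset.column_names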

@staticmethod
-def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[str]:
+def _retrieve_gcs_source_columns(
+    project: str,
+    gcs_csv_file_path: str,
+    credentials: Optional[auth_credentials.Credentials] = None
+) -> List[str]:
"""Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage

Example Usage:
@@ -104,7 +112,8 @@ def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[str]:
gcs_csv_file_path (str):
Required. A full path to a CSV file stored on Google Cloud Storage.
Must include "gs://" prefix.

+credentials (auth_credentials.Credentials):
+    Credentials to use with the GCS Client.
Returns:
List[str]
A list of column names in the CSV file.
@@ -116,7 +125,7 @@ def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[str]:
gcs_bucket, gcs_blob = utils.extract_bucket_and_prefix_from_gcs_path(
gcs_csv_file_path
)
-client = storage.Client(project=project)
+client = storage.Client(project=project, credentials=credentials)
bucket = client.bucket(gcs_bucket)
blob = bucket.blob(gcs_blob)

@@ -156,7 +165,11 @@ def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[str]:
return next(csv_reader)
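Worth noting for backward compatibility: when credentials is None (the default), the google-cloud clients resolve Application Default Credentials, so callers that never pass credentials see no behavior change. A minimal illustration:

from google.cloud import storage

# credentials=None makes the client fall back to Application Default
# Credentials, matching the behavior before this change.
client = storage.Client(project="my-project", credentials=None)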

@staticmethod
-def _retrieve_bq_source_columns(project: str, bq_table_uri: str) -> List[str]:
+def _retrieve_bq_source_columns(
+    project: str,
+    bq_table_uri: str,
+    credentials: Optional[auth_credentials.Credentials] = None
+) -> List[str]:
"""Retrieve the columns from a table on Google BigQuery

Example Usage:
@@ -174,6 +187,8 @@ def _retrieve_bq_source_columns(project: str, bq_table_uri: str) -> List[str]:
bq_table_uri (str):
Required. A URI to a BigQuery table.
Can include the "bq://" prefix, but it is not required.
+credentials (auth_credentials.Credentials):
+    Credentials to use with the BQ Client.

Returns:
List[str]
@@ -185,7 +200,7 @@ def _retrieve_bq_source_columns(project: str, bq_table_uri: str) -> List[str]:
if bq_table_uri.startswith(prefix):
bq_table_uri = bq_table_uri[len(prefix) :]

-client = bigquery.Client(project=project)
+client = bigquery.Client(project=project, credentials=credentials)
table = client.get_table(bq_table_uri)
schema = table.schema
return [schema.name for schema in schema]
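As a standalone sketch of the lookup this method performs (project and table IDs here are hypothetical, and Application Default Credentials stand in for explicitly supplied ones):

import google.auth
from google.cloud import bigquery

creds, _ = google.auth.default()

client = bigquery.Client(project="my-project", credentials=creds)
table = client.get_table("my-project.my_dataset.my_table")

# Each SchemaField in table.schema exposes the column name via .name
print([field.name for field in table.schema])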
21 changes: 21 additions & 0 deletions tests/unit/aiplatform/test_datasets.py
@@ -345,6 +345,17 @@ def gcs_client_download_as_bytes_mock():
bigquery_blob_mock.return_value = b'"column_1","column_2"\n0, 1'
yield bigquery_blob_mock

+@pytest.fixture
+def gcs_client_download_as_bytes_mock_with_creds():
+    # Yield rather than return so the patch stays active for the duration
+    # of the test; returning would exit the context manager and undo the
+    # patch before the test body runs.
+    with patch.object(storage, "Client", autospec=True) as client_mock:
+        yield client_mock


@pytest.fixture
def bigquery_client_mock():
@@ -995,6 +1006,16 @@ def test_tabular_dataset_column_name_gcs(self):

assert my_dataset.column_names == ["column_1", "column_2"]

+@pytest.mark.usefixtures(
+    "get_dataset_tabular_gcs_mock", "gcs_client_download_as_bytes_mock"
+)
+def test_tabular_dataset_column_name_gcs_with_creds(
+    self, gcs_client_download_as_bytes_mock_with_creds
+):
+    my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME)

+    assert my_dataset.column_names == ["column_1", "column_2"]

+    # assert_called_once_with returns None on success, so it must not be
+    # wrapped in a bare assert; the client should also be constructed with
+    # the dataset's credentials after this change.
+    gcs_client_download_as_bytes_mock_with_creds.assert_called_once_with(
+        project=_TEST_PROJECT, credentials=my_dataset.credentials
+    )
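An aside on the assertion fixed above: Mock.assert_called_once_with raises AssertionError on a mismatch and returns None on success, so wrapping it in a bare assert fails even when the expectation holds. A minimal demonstration:

from unittest.mock import MagicMock

m = MagicMock()
m(project="p")

m.assert_called_once_with(project="p")  # passes; raises only on mismatch
# assert m.assert_called_once_with(project="p")  # always fails: assert None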

@pytest.mark.usefixtures(
"get_dataset_tabular_bq_mock",
"bigquery_client_mock",