fix: pass credentials to BQ and GCS clients #469

Merged
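For context, a minimal sketch of the behavior this PR fixes (the project, location, dataset resource name, and key file below are hypothetical): credentials passed to TabularDataset at construction should be forwarded to the storage.Client and bigquery.Client that column_names creates internally, instead of those clients silently falling back to application-default credentials.

from google.cloud import aiplatform
from google.oauth2 import service_account

# Hypothetical key file; any google.auth credentials object works here.
creds = service_account.Credentials.from_service_account_file("key.json")

aiplatform.init(project="my-project", location="us-central1")

my_dataset = aiplatform.TabularDataset(
    dataset_name="projects/my-project/locations/us-central1/datasets/1234567890",
    credentials=creds,
)

# Before this fix, the GCS/BigQuery clients built inside column_names
# ignored `creds` and used application-default credentials instead.
print(my_dataset.column_names)
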
2 changes: 1 addition & 1 deletion google/cloud/aiplatform/datasets/dataset.py
@@ -68,7 +68,7 @@ def __init__(
Optional location to retrieve dataset from. If not set, location
set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
Custom credentials to use to upload this model. Overrides
Custom credentials to use to retrieve this Dataset. Overrides
credentials set in aiplatform.init.
"""

34 changes: 25 additions & 9 deletions google/cloud/aiplatform/datasets/tabular_dataset.py
@@ -73,20 +73,28 @@ def column_names(self) -> List[str]:
gcs_source_uris.sort()

# Get the first file in sorted list
return TabularDataset._retrieve_gcs_source_columns(
self.project, gcs_source_uris[0]
return self._retrieve_gcs_source_columns(
Contributor

Is there any use case where the user would want to pass in different credentials when calling dataset.column_names?

If not, this looks fine.

project=self.project,
gcs_csv_file_path=gcs_source_uris[0],
credentials=self.credentials,
)
elif bq_source:
bq_table_uri = bq_source.get("uri")
if bq_table_uri:
return TabularDataset._retrieve_bq_source_columns(
self.project, bq_table_uri
return self._retrieve_bq_source_columns(
project=self.project,
bq_table_uri=bq_table_uri,
credentials=self.credentials,
)

raise RuntimeError("No valid CSV or BigQuery datasource found.")

@staticmethod
def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[str]:
def _retrieve_gcs_source_columns(
project: str,
gcs_csv_file_path: str,
credentials: Optional[auth_credentials.Credentials] = None,
) -> List[str]:
"""Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage

Example Usage:
@@ -104,7 +112,8 @@ def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[str]:
gcs_csv_file_path (str):
Required. A full path to a CSV file stored on Google Cloud Storage.
Must include "gs://" prefix.

credentials (auth_credentials.Credentials):
Credentials to use with the GCS Client.
Returns:
List[str]
A list of columns names in the CSV file.
@@ -116,7 +125,7 @@ def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[str]:
gcs_bucket, gcs_blob = utils.extract_bucket_and_prefix_from_gcs_path(
gcs_csv_file_path
)
client = storage.Client(project=project)
client = storage.Client(project=project, credentials=credentials)
bucket = client.bucket(gcs_bucket)
blob = bucket.blob(gcs_blob)

@@ -135,6 +144,7 @@ def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[str]:
line += blob.download_as_bytes(
start=start_index, end=start_index + increment
).decode("utf-8")

first_new_line_index = line.find("\n")
start_index += increment

@@ -156,7 +166,11 @@ def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[str]:
return next(csv_reader)
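
The loop above fetches the CSV header without downloading the whole object, by requesting successive byte ranges until a newline appears. A standalone sketch of the same technique, assuming hypothetical bucket and blob names (an illustration, not the library's exact implementation):

import csv
from typing import List, Optional

from google.auth import credentials as auth_credentials
from google.cloud import storage


def read_csv_header(
    bucket_name: str,
    blob_name: str,
    credentials: Optional[auth_credentials.Credentials] = None,
) -> List[str]:
    client = storage.Client(credentials=credentials)
    blob = client.bucket(bucket_name).blob(blob_name)

    line, start, step = "", 0, 1000
    # Assumes the object contains at least one newline.
    while "\n" not in line:
        # download_as_bytes accepts start/end offsets, so only the first
        # chunk(s) of the object are transferred.
        line += blob.download_as_bytes(start=start, end=start + step).decode("utf-8")
        start += step

    # Parse only the first line with the csv module so quoting is respected.
    return next(csv.reader([line.split("\n")[0]]))
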

@staticmethod
def _retrieve_bq_source_columns(project: str, bq_table_uri: str) -> List[str]:
def _retrieve_bq_source_columns(
project: str,
bq_table_uri: str,
credentials: Optional[auth_credentials.Credentials] = None,
) -> List[str]:
"""Retrieve the columns from a table on Google BigQuery

Example Usage:
@@ -174,6 +188,8 @@ def _retrieve_bq_source_columns(project: str, bq_table_uri: str) -> List[str]:
bq_table_uri (str):
Required. A URI to a BigQuery table.
Can include "bq://" prefix but not required.
credentials (auth_credentials.Credentials):
Credentials to use with the BigQuery Client.

Returns:
List[str]
@@ -185,7 +201,7 @@ def _retrieve_bq_source_columns(project: str, bq_table_uri: str) -> List[str]:
if bq_table_uri.startswith(prefix):
bq_table_uri = bq_table_uri[len(prefix) :]

client = bigquery.Client(project=project)
client = bigquery.Client(project=project, credentials=credentials)
table = client.get_table(bq_table_uri)
schema = table.schema
return [schema.name for schema in schema]
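
For comparison, a minimal standalone version of the BigQuery path with an explicit credentials pass-through (the table ID is hypothetical). It also avoids reusing the name schema for both the list and the loop variable, as the comprehension above does:

from typing import List, Optional

from google.auth import credentials as auth_credentials
from google.cloud import bigquery


def bq_column_names(
    project: str,
    table_id: str,  # hypothetical, e.g. "my-project.my_dataset.my_table"
    credentials: Optional[auth_credentials.Credentials] = None,
) -> List[str]:
    client = bigquery.Client(project=project, credentials=credentials)
    table = client.get_table(table_id)
    # table.schema is a list of SchemaField objects.
    return [field.name for field in table.schema]
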
58 changes: 50 additions & 8 deletions tests/unit/aiplatform/test_datasets.py
@@ -341,16 +341,30 @@ def list_datasets_mock():

@pytest.fixture
def gcs_client_download_as_bytes_mock():
with patch.object(storage.Blob, "download_as_bytes") as bigquery_blob_mock:
bigquery_blob_mock.return_value = b'"column_1","column_2"\n0, 1'
yield bigquery_blob_mock
with patch.object(storage.Blob, "download_as_bytes") as gcs_blob_mock:
gcs_blob_mock.return_value = b'"column_1","column_2"\n0, 1'
yield gcs_blob_mock


@pytest.fixture
def bigquery_client_mock():
with patch.object(bigquery.Client, "get_table") as bigquery_client_mock:
bigquery_client_mock.return_value = bigquery.Table("project.dataset.table")
yield bigquery_client_mock
def gcs_client_mock():
with patch.object(storage, "Client") as client_mock:
yield client_mock


@pytest.fixture
def bq_client_mock():
with patch.object(bigquery, "Client") as client_mock:
yield client_mock


@pytest.fixture
def bigquery_client_table_mock():
with patch.object(bigquery.Client, "get_table") as bigquery_client_table_mock:
bigquery_client_table_mock.return_value = bigquery.Table(
"project.dataset.table"
)
yield bigquery_client_table_mock


@pytest.fixture
@@ -995,9 +1009,37 @@ def test_tabular_dataset_column_name_gcs(self):

assert my_dataset.column_names == ["column_1", "column_2"]

@pytest.mark.usefixtures("get_dataset_tabular_gcs_mock")
def test_tabular_dataset_column_name_gcs_with_creds(self, gcs_client_mock):
creds = auth_credentials.AnonymousCredentials()
my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME, credentials=creds)

# we are just testing credentials passing here;
# the StopIteration is from the mock not returning
# the CSV data, which is tested above
try:
my_dataset.column_names
except StopIteration:
pass

gcs_client_mock.assert_called_once_with(
project=_TEST_PROJECT, credentials=creds
)

@pytest.mark.usefixtures("get_dataset_tabular_bq_mock",)
def test_tabular_dataset_column_name_bq_with_creds(self, bq_client_mock):
creds = auth_credentials.AnonymousCredentials()
my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME, credentials=creds)

my_dataset.column_names

assert bq_client_mock.call_args_list[0] == mock.call(
project=_TEST_PROJECT, credentials=creds
)

@pytest.mark.usefixtures(
"get_dataset_tabular_bq_mock",
"bigquery_client_mock",
"bigquery_client_table_mock",
"bigquery_table_schema_mock",
)
def test_tabular_dataset_column_name_bigquery(self):
Expand Down