fix: pass credentials to BQ and GCS clients (#469)
sasha-gitg committed Jun 11, 2021
1 parent c2cf612 commit 481d172
Showing 3 changed files with 76 additions and 18 deletions.
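In user terms, this fix means credentials supplied at dataset construction now reach the GCS and BigQuery clients behind column_names, instead of silently falling back to application default credentials. A minimal sketch of the affected call path, assuming a hypothetical service-account key file and dataset resource name:

    from google.cloud import aiplatform
    from google.oauth2 import service_account

    # Hypothetical key file and dataset resource name.
    creds = service_account.Credentials.from_service_account_file("sa-key.json")
    my_dataset = aiplatform.TabularDataset(
        "projects/123/locations/us-central1/datasets/456", credentials=creds
    )

    # Before this commit, the storage client built here ignored `creds`
    # and used application default credentials instead.
    my_dataset.column_names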
2 changes: 1 addition & 1 deletion google/cloud/aiplatform/datasets/dataset.py
@@ -68,7 +68,7 @@ def __init__(
Optional location to retrieve dataset from. If not set, location
set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
Custom credentials to use to upload this model. Overrides
Custom credentials to use to retrieve this Dataset. Overrides
credentials set in aiplatform.init.
"""

34 changes: 25 additions & 9 deletions google/cloud/aiplatform/datasets/tabular_dataset.py
@@ -73,20 +73,28 @@ def column_names(self) -> List[str]:
gcs_source_uris.sort()

# Get the first file in sorted list
return TabularDataset._retrieve_gcs_source_columns(
self.project, gcs_source_uris[0]
return self._retrieve_gcs_source_columns(
project=self.project,
gcs_csv_file_path=gcs_source_uris[0],
credentials=self.credentials,
)
elif bq_source:
bq_table_uri = bq_source.get("uri")
if bq_table_uri:
return TabularDataset._retrieve_bq_source_columns(
self.project, bq_table_uri
return self._retrieve_bq_source_columns(
project=self.project,
bq_table_uri=bq_table_uri,
credentials=self.credentials,
)

raise RuntimeError("No valid CSV or BigQuery datasource found.")

@staticmethod
def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[str]:
def _retrieve_gcs_source_columns(
project: str,
gcs_csv_file_path: str,
credentials: Optional[auth_credentials.Credentials] = None,
) -> List[str]:
"""Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage
Example Usage:
@@ -104,7 +112,8 @@ def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[str]:
gcs_csv_file_path (str):
Required. A full path to a CSV file stored on Google Cloud Storage.
Must include "gs://" prefix.
credentials (auth_credentials.Credentials):
Credentials to use with the GCS Client.
Returns:
List[str]
A list of columns names in the CSV file.
@@ -116,7 +125,7 @@ def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[str]:
gcs_bucket, gcs_blob = utils.extract_bucket_and_prefix_from_gcs_path(
gcs_csv_file_path
)
client = storage.Client(project=project)
client = storage.Client(project=project, credentials=credentials)
bucket = client.bucket(gcs_bucket)
blob = bucket.blob(gcs_blob)

@@ -135,6 +144,7 @@ def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[str]:
line += blob.download_as_bytes(
start=start_index, end=start_index + increment
).decode("utf-8")

first_new_line_index = line.find("\n")
start_index += increment

@@ -156,7 +166,11 @@ def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[str]:
return next(csv_reader)
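
The helper above reads just the CSV header by downloading the blob in byte-range increments until the first newline appears. A standalone sketch of that technique with the credentials pass-through, using hypothetical bucket and blob names:

    import csv

    from google.cloud import storage

    def read_csv_header(project, bucket_name, blob_name, credentials=None):
        """Return the header row of a CSV blob without downloading the whole file."""
        client = storage.Client(project=project, credentials=credentials)
        blob = client.bucket(bucket_name).blob(blob_name)

        line = ""
        start_index = 0
        increment = 1000
        # Assumes the file contains at least one newline, as the helper above does.
        while line.find("\n") < 0:
            line += blob.download_as_bytes(
                start=start_index, end=start_index + increment
            ).decode("utf-8")
            start_index += increment

        return next(csv.reader([line[: line.find("\n")]]))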

@staticmethod
def _retrieve_bq_source_columns(project: str, bq_table_uri: str) -> List[str]:
def _retrieve_bq_source_columns(
project: str,
bq_table_uri: str,
credentials: Optional[auth_credentials.Credentials] = None,
) -> List[str]:
"""Retrieve the columns from a table on Google BigQuery
Example Usage:
@@ -174,6 +188,8 @@ def _retrieve_bq_source_columns(project: str, bq_table_uri: str) -> List[str]:
bq_table_uri (str):
Required. A URI to a BigQuery table.
May include the "bq://" prefix, but it is not required.
credentials (auth_credentials.Credentials):
Credentials to use with the BQ Client.
Returns:
List[str]
@@ -185,7 +201,7 @@ def _retrieve_bq_source_columns(project: str, bq_table_uri: str) -> List[str]:
if bq_table_uri.startswith(prefix):
bq_table_uri = bq_table_uri[len(prefix) :]

client = bigquery.Client(project=project)
client = bigquery.Client(project=project, credentials=credentials)
table = client.get_table(bq_table_uri)
schema = table.schema
return [field.name for field in schema]
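The BigQuery path is simpler; an equivalent standalone call, assuming a hypothetical table ID (credentials=None falls back to application default credentials):

    from google.cloud import bigquery

    client = bigquery.Client(project="my-project", credentials=None)
    table = client.get_table("my-project.my_dataset.my_table")
    column_names = [field.name for field in table.schema]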
58 changes: 50 additions & 8 deletions tests/unit/aiplatform/test_datasets.py
@@ -341,16 +341,30 @@ def list_datasets_mock():

@pytest.fixture
def gcs_client_download_as_bytes_mock():
with patch.object(storage.Blob, "download_as_bytes") as bigquery_blob_mock:
bigquery_blob_mock.return_value = b'"column_1","column_2"\n0, 1'
yield bigquery_blob_mock
with patch.object(storage.Blob, "download_as_bytes") as gcs_blob_mock:
gcs_blob_mock.return_value = b'"column_1","column_2"\n0, 1'
yield gcs_blob_mock


@pytest.fixture
def bigquery_client_mock():
with patch.object(bigquery.Client, "get_table") as bigquery_client_mock:
bigquery_client_mock.return_value = bigquery.Table("project.dataset.table")
yield bigquery_client_mock
def gcs_client_mock():
with patch.object(storage, "Client") as client_mock:
yield client_mock


@pytest.fixture
def bq_client_mock():
with patch.object(bigquery, "Client") as client_mock:
yield client_mock


@pytest.fixture
def bigquery_client_table_mock():
with patch.object(bigquery.Client, "get_table") as bigquery_client_table_mock:
bigquery_client_table_mock.return_value = bigquery.Table(
"project.dataset.table"
)
yield bigquery_client_table_mock
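
The new fixtures patch the client constructors themselves, rather than individual methods, so a test can assert on the keyword arguments the constructor received. The pattern in isolation, as a self-contained sketch:

    from unittest import mock

    from google.cloud import storage

    with mock.patch.object(storage, "Client") as client_mock:
        # The code under test constructs its client like this...
        storage.Client(project="my-project", credentials="fake-creds")
        # ...and the test can verify credentials were actually passed through.
        client_mock.assert_called_once_with(
            project="my-project", credentials="fake-creds"
        )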


@pytest.fixture
@@ -995,9 +1009,37 @@ def test_tabular_dataset_column_name_gcs(self):

assert my_dataset.column_names == ["column_1", "column_2"]

@pytest.mark.usefixtures("get_dataset_tabular_gcs_mock")
def test_tabular_dataset_column_name_gcs_with_creds(self, gcs_client_mock):
creds = auth_credentials.AnonymousCredentials()
my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME, credentials=creds)

# We are only testing that credentials are passed through;
# this exception is from the mock not returning
# the CSV data, which is tested above.
try:
my_dataset.column_names
except StopIteration:
pass

gcs_client_mock.assert_called_once_with(
project=_TEST_PROJECT, credentials=creds
)

@pytest.mark.usefixtures("get_dataset_tabular_bq_mock")
def test_tabular_dataset_column_name_bq_with_creds(self, bq_client_mock):
creds = auth_credentials.AnonymousCredentials()
my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME, credentials=creds)

my_dataset.column_names

assert bq_client_mock.call_args_list[0] == mock.call(
project=_TEST_PROJECT, credentials=creds
)

@pytest.mark.usefixtures(
"get_dataset_tabular_bq_mock",
"bigquery_client_mock",
"bigquery_client_table_mock",
"bigquery_table_schema_mock",
)
def test_tabular_dataset_column_name_bigquery(self):
