Implement initial Dataset class (googleapis#31)
* Dataset class (#1)

* Add global_config unset project error

* Add two validation functions to utils

* Implement initial Dataset class

* Lint utils.py

* Address reviewer comments + remove aip alias

* Change re.Match to typing.Match

* Lint with Py 3.8

* Address flake8 errors, remove unused vars
vinnysenthil committed Oct 20, 2020
1 parent e43bee1 commit d65512e
Showing 9 changed files with 688 additions and 42 deletions.
3 changes: 2 additions & 1 deletion google/cloud/aiplatform/__init__.py
@@ -19,6 +19,7 @@

 from google.cloud.aiplatform import initializer
 from google.cloud.aiplatform.models import Model
+from google.cloud.aiplatform.datasets import Dataset

 """
 Usage:
@@ -29,4 +30,4 @@
 init = initializer.global_config.init


-__all__ = ()
+__all__ = ("gapic", "Model", "Dataset")
395 changes: 393 additions & 2 deletions google/cloud/aiplatform/datasets.py

Large diffs are not rendered by default.
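Since the datasets.py diff is collapsed above, here is a minimal usage sketch of the new Dataset class as exercised by the unit tests added in this commit. Every name and parameter below is taken from those tests rather than from the hidden diff, so treat the surface as inferred and the project/bucket values as placeholders:

from google.cloud import aiplatform
from google.cloud.aiplatform import Dataset

# Set process-wide defaults once; project and location are placeholders.
aiplatform.init(project="my-project", location="us-central1")

# Create a dataset and import JSONL source files from GCS in one call.
my_dataset = Dataset.create(
    display_name="my_dataset_1234",
    metadata_schema_uri="gs://my-bucket/schema-9876.yaml",
    source="gs://my-bucket/my_index_file.jsonl",
    import_schema_uri="gs://google-cloud-aiplatform/schemas/1.0.0.yaml",
    data_items_labels={},
)

# Retrieve an existing dataset by full resource name or bare numeric ID.
existing = Dataset(dataset_name="1028944691210842416")

# Import more data into it, or export its contents to a GCS prefix.
existing.import_data(
    gcs_source="gs://my-bucket/index_file_1.jsonl",
    import_schema_uri="gs://google-cloud-aiplatform/schemas/1.0.0.yaml",
    data_items_labels={},
)
existing.export_data(output_dir="gs://my-output-bucket")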

19 changes: 13 additions & 6 deletions google/cloud/aiplatform/initializer.py
@@ -17,11 +17,12 @@


 import logging
-from typing import Dict, Optional, Type
+from typing import Optional, Type

 from google.api_core import client_options
 import google.auth
 from google.auth import credentials as auth_credentials
+from google.auth.exceptions import GoogleAuthError
 from google.cloud.aiplatform import utils


@@ -76,13 +77,21 @@ def project(self) -> str:
         if self._project:
             return self._project

-        _, project_id = google.auth.default()
+        try:
+            _, project_id = google.auth.default()
+        except GoogleAuthError:
+            raise GoogleAuthError(
+                "Unable to find your project. Please provide a project ID by:"
+                "\n- Passing a constructor argument"
+                "\n- Using aiplatform.init()"
+                "\n- Setting a GCP environment variable"
+            )
         return project_id

     @property
     def location(self) -> str:
         """Default location."""
-        return self._location if self._location else utils.DEFAULT_REGION
+        return self._location or utils.DEFAULT_REGION

     @property
     def experiment(self) -> Optional[str]:
@@ -100,9 +109,7 @@ def credentials(self) -> Optional[auth_credentials.Credentials]:
         return self._credentials

     def get_client_options(
-        self,
-        location_override: Optional[str] = None,
-        prediction_client: bool = False,
+        self, location_override: Optional[str] = None, prediction_client: bool = False,
     ) -> client_options.ClientOptions:
         """Creates GAPIC client_options using location and type.
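A short sketch of the new project-resolution behavior above: when no project is configured and application-default credentials cannot supply one, reading global_config.project now surfaces a GoogleAuthError carrying the remediation message added in this diff. The snippet assumes an environment with no ambient project configured:

from google.auth.exceptions import GoogleAuthError
from google.cloud import aiplatform
from google.cloud.aiplatform import initializer

try:
    # No project passed anywhere: the property falls back to
    # google.auth.default(), and failure surfaces as GoogleAuthError.
    _ = initializer.global_config.project
except GoogleAuthError as err:
    print(err)  # "Unable to find your project. ..."

# Any of the three remedies named in the error avoids it, e.g.:
aiplatform.init(project="my-project")
assert initializer.global_config.project == "my-project"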
4 changes: 0 additions & 4 deletions google/cloud/aiplatform/models.py
@@ -15,10 +15,6 @@
 # limitations under the License.
 #

-from google.cloud.aiplatform_v1beta1.services.model_service import (
-    client as model_client,
-)
-
 from typing import Dict, Optional, Sequence

 from google.auth import credentials as auth_credentials
20 changes: 9 additions & 11 deletions google/cloud/aiplatform/utils.py
@@ -18,7 +18,7 @@

 import re

-from typing import Optional, TypeVar, Union
+from typing import Optional, TypeVar, Match
 from collections import namedtuple

 from google.cloud.aiplatform_v1beta1.services.dataset_service import (
@@ -43,19 +43,12 @@
 RESOURCE_NAME_PATTERN = re.compile(
     r"^projects\/(?P<project>[\w-]+)\/locations\/(?P<location>[\w-]+)\/(?P<resource>\w+)\/(?P<id>\d+)$"
 )
+RESOURCE_ID_PATTERN = re.compile(r"^\d+$")

-Fields = namedtuple(
-    "Fields",
-    [
-        "project",
-        "location",
-        "resource",
-        "id",
-    ],
-)
+Fields = namedtuple("Fields", ["project", "location", "resource", "id"],)


-def _match_to_fields(match: re.Match) -> Optional[Fields]:
+def _match_to_fields(match: Match) -> Optional[Fields]:
"""Normalize RegEx groups from resource name pattern Match to class Fields"""
if not match:
return None
Expand All @@ -68,6 +61,11 @@ def _match_to_fields(match: re.Match) -> Optional[Fields]:
)


+def validate_id(resource_id: str) -> bool:
+    """Validate int64 resource ID number"""
+    return bool(RESOURCE_ID_PATTERN.match(resource_id))
+
+
 def extract_fields_from_resource_name(
     resource_name: str, resource_noun: Optional[str] = None
 ) -> Optional[Fields]:
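To make the two new utils additions concrete, a small sketch assuming only what is visible in this diff; the behavior of extract_fields_from_resource_name is inferred from its signature and the _match_to_fields helper, since its body continues below the fold:

from google.cloud.aiplatform import utils

# validate_id: bare int64 resource IDs pass, anything else fails.
assert utils.validate_id("1028944691210842416")
assert not utils.validate_id("my-dataset")

# RESOURCE_NAME_PATTERN decomposes a full resource name into Fields.
fields = utils.extract_fields_from_resource_name(
    "projects/test-project/locations/us-central1/datasets/1028944691210842416"
)
assert fields.project == "test-project"
assert fields.resource == "datasets"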
258 changes: 258 additions & 0 deletions tests/unit/aiplatform/test_datasets.py
@@ -0,0 +1,258 @@
# -*- coding: utf-8 -*-

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import pytest

from unittest import mock
from importlib import reload
from unittest.mock import patch

from google.api_core import operation
from google.auth.exceptions import GoogleAuthError
from google.auth import credentials as auth_credentials

from google.cloud import aiplatform
from google.cloud.aiplatform import Dataset
from google.cloud.aiplatform import initializer

from google.cloud.aiplatform_v1beta1 import GcsSource
from google.cloud.aiplatform_v1beta1 import GcsDestination
from google.cloud.aiplatform_v1beta1 import ImportDataConfig
from google.cloud.aiplatform_v1beta1 import ExportDataConfig
from google.cloud.aiplatform_v1beta1 import DatasetServiceClient
from google.cloud.aiplatform_v1beta1 import Dataset as GapicDataset

_TEST_PROJECT = "test-project"
_TEST_LOCATION = "us-central1"
_TEST_ALT_LOCATION = "europe-west4"
_TEST_ID = "1028944691210842416"
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
_TEST_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/datasets/{_TEST_ID}"
_TEST_ALT_NAME = (
f"projects/{_TEST_PROJECT}/locations/{_TEST_ALT_LOCATION}/datasets/{_TEST_ID}"
)

_TEST_INVALID_LOCATION = "us-central2"
_TEST_INVALID_NAME = f"prj/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/{_TEST_ID}"

_TEST_LABEL = {"team": "experimentation", "trial_id": "x435"}
_TEST_DISPLAY_NAME = "my_dataset_1234"
_TEST_METADATA_SCHEMA_URI = "gs://my-bucket/schema-9876.yaml"

_TEST_IMPORT_SCHEMA_URI = "gs://google-cloud-aiplatform/schemas/1.0.0.yaml"
_TEST_SOURCE_URI = "gs://my-bucket/my_index_file.jsonl"
_TEST_SOURCE_URIS = [
"gs://my-bucket/index_file_1.jsonl",
"gs://my-bucket/index_file_2.jsonl",
"gs://my-bucket/index_file_3.jsonl",
]
_TEST_INVALID_SOURCE_URIS = ["gs://my-bucket/index_file_1.jsonl", 123]
_TEST_DATA_LABEL_ITEMS = {}

_TEST_OUTPUT_DIR = "gs://my-output-bucket"


# TODO(b/171333554): Move reusable test fixtures to conftest.py file
class TestDataset:
    def setup_method(self):
        reload(initializer)
        reload(aiplatform)

    @pytest.fixture
    def get_dataset_mock(self):
        with patch.object(DatasetServiceClient, "get_dataset") as get_dataset_mock:
            get_dataset_mock.return_value = GapicDataset(
                display_name=_TEST_DISPLAY_NAME,
                metadata_schema_uri=_TEST_METADATA_SCHEMA_URI,
                labels=_TEST_LABEL,
                name=_TEST_NAME,
            )
            yield get_dataset_mock

    @pytest.fixture
    def get_dataset_without_name_mock(self):
        with patch.object(DatasetServiceClient, "get_dataset") as get_dataset_mock:
            get_dataset_mock.return_value = GapicDataset(
                display_name=_TEST_DISPLAY_NAME,
                metadata_schema_uri=_TEST_METADATA_SCHEMA_URI,
                labels=_TEST_LABEL,
            )
            yield get_dataset_mock

    @pytest.fixture
    def create_dataset_mock(self):
        with patch.object(
            DatasetServiceClient, "create_dataset"
        ) as create_dataset_mock:
            create_dataset_lro_mock = mock.Mock(operation.Operation)
            create_dataset_lro_mock.result.return_value = GapicDataset(
                name=_TEST_NAME, display_name=_TEST_DISPLAY_NAME
            )
            create_dataset_mock.return_value = create_dataset_lro_mock
            yield create_dataset_mock

    @pytest.fixture
    def import_data_mock(self):
        with patch.object(DatasetServiceClient, "import_data") as import_data_mock:
            import_data_mock.return_value = mock.Mock(operation.Operation)
            yield import_data_mock

    @pytest.fixture
    def export_data_mock(self):
        with patch.object(DatasetServiceClient, "export_data") as export_data_mock:
            export_data_mock.return_value = mock.Mock(operation.Operation)
            yield export_data_mock

    def test_init_dataset(self, get_dataset_mock):
        Dataset(dataset_name=_TEST_NAME)
        get_dataset_mock.assert_called_once_with(name=_TEST_NAME)

    def test_init_dataset_with_id_only(self, get_dataset_mock):
        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
        Dataset(dataset_name=_TEST_ID)
        get_dataset_mock.assert_called_once_with(name=_TEST_NAME)

    @pytest.mark.usefixtures("get_dataset_without_name_mock")
    @patch.dict(
        os.environ, {"GOOGLE_CLOUD_PROJECT": "", "GOOGLE_APPLICATION_CREDENTIALS": ""}
    )
    def test_init_dataset_with_id_only_without_project_or_location(self):
        with pytest.raises(GoogleAuthError):
            Dataset(
                dataset_name=_TEST_ID,
                credentials=auth_credentials.AnonymousCredentials(),
            )

    def test_init_dataset_with_location_override(self, get_dataset_mock):
        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
        Dataset(dataset_name=_TEST_ID, location=_TEST_ALT_LOCATION)
        get_dataset_mock.assert_called_once_with(name=_TEST_ALT_NAME)

    @pytest.mark.usefixtures("get_dataset_mock")
    def test_init_dataset_with_invalid_name(self):
        with pytest.raises(ValueError):
            aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
            Dataset(dataset_name=_TEST_INVALID_NAME)

    @pytest.mark.usefixtures("get_dataset_mock")
    def test_create_dataset(self, create_dataset_mock):
        aiplatform.init(project=_TEST_PROJECT)

        Dataset.create(
            display_name=_TEST_DISPLAY_NAME,
            metadata_schema_uri=_TEST_METADATA_SCHEMA_URI,
            labels=_TEST_LABEL,
        )

        expected_dataset = GapicDataset(
            display_name=_TEST_DISPLAY_NAME,
            metadata_schema_uri=_TEST_METADATA_SCHEMA_URI,
            labels=_TEST_LABEL,
        )

        create_dataset_mock.assert_called_once_with(
            parent=_TEST_PARENT, dataset=expected_dataset, metadata=()
        )

    @pytest.mark.usefixtures("get_dataset_mock")
    def test_create_and_import_dataset(self, create_dataset_mock, import_data_mock):
        aiplatform.init(project=_TEST_PROJECT)

        my_dataset = Dataset.create(
            display_name=_TEST_DISPLAY_NAME,
            source=_TEST_SOURCE_URI,
            labels=_TEST_LABEL,
            metadata_schema_uri=_TEST_METADATA_SCHEMA_URI,
            import_schema_uri=_TEST_IMPORT_SCHEMA_URI,
            data_items_labels=_TEST_DATA_LABEL_ITEMS,
        )

        expected_dataset = GapicDataset(
            display_name=_TEST_DISPLAY_NAME,
            metadata_schema_uri=_TEST_METADATA_SCHEMA_URI,
            labels=_TEST_LABEL,
        )

        expected_import_config = ImportDataConfig(
            gcs_source=GcsSource(uris=[_TEST_SOURCE_URI]),
            import_schema_uri=_TEST_IMPORT_SCHEMA_URI,
            data_item_labels=_TEST_DATA_LABEL_ITEMS,
        )

        create_dataset_mock.assert_called_once_with(
            parent=_TEST_PARENT, dataset=expected_dataset, metadata=()
        )

        import_data_mock.assert_called_once_with(
            name=_TEST_NAME, import_configs=[expected_import_config]
        )

        expected_dataset.name = _TEST_NAME
        assert my_dataset._gca_resource == expected_dataset

    @pytest.mark.usefixtures("get_dataset_mock")
    def test_create_and_import_dataset_without_import_schema_uri(
        self, create_dataset_mock
    ):
        with pytest.raises(ValueError):
            aiplatform.init(project=_TEST_PROJECT)

            Dataset.create(
                display_name=_TEST_DISPLAY_NAME,
                metadata_schema_uri=_TEST_METADATA_SCHEMA_URI,
                labels=_TEST_LABEL,
                source=_TEST_SOURCE_URI,
            )

    @pytest.mark.usefixtures("get_dataset_mock")
    def test_import_data(self, import_data_mock):
        aiplatform.init(project=_TEST_PROJECT)

        my_dataset = Dataset(dataset_name=_TEST_NAME)

        my_dataset.import_data(
            gcs_source=_TEST_SOURCE_URI,
            import_schema_uri=_TEST_IMPORT_SCHEMA_URI,
            data_items_labels=_TEST_DATA_LABEL_ITEMS,
        )

        expected_import_config = ImportDataConfig(
            gcs_source=GcsSource(uris=[_TEST_SOURCE_URI]),
            import_schema_uri=_TEST_IMPORT_SCHEMA_URI,
            data_item_labels=_TEST_DATA_LABEL_ITEMS,
        )

        import_data_mock.assert_called_once_with(
            name=_TEST_NAME, import_configs=[expected_import_config]
        )

    @pytest.mark.usefixtures("get_dataset_mock")
    def test_export_data(self, export_data_mock):
        aiplatform.init(project=_TEST_PROJECT)

        my_dataset = Dataset(dataset_name=_TEST_NAME)

        my_dataset.export_data(output_dir=_TEST_OUTPUT_DIR)

        expected_export_config = ExportDataConfig(
            gcs_destination=GcsDestination(output_uri_prefix=_TEST_OUTPUT_DIR)
        )

        export_data_mock.assert_called_once_with(
            name=_TEST_NAME, export_config=expected_export_config
        )
1 change: 0 additions & 1 deletion tests/unit/aiplatform/test_initializer.py
@@ -19,7 +19,6 @@
 import pytest
 import importlib

-from google.api_core import client_options
 import google.auth
 from google.auth import credentials
 from google.cloud.aiplatform import initializer
