Skip to content

Commit

Permalink
fix: add v1 conversion value rule
Browse files Browse the repository at this point in the history
  • Loading branch information
sasha-gitg authored and morgandu committed Apr 7, 2021
1 parent 3ce0163 commit 6356e96
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 37 deletions.
2 changes: 2 additions & 0 deletions google/cloud/aiplatform/helpers/_decorators.py
Expand Up @@ -68,3 +68,5 @@ def _from_map(map_):

marshal = Marshal(name="google.cloud.aiplatform.v1beta1")
marshal.register(Value, ConversionValueRule(marshal=marshal))
marshal = Marshal(name="google.cloud.aiplatform.v1")
marshal.register(Value, ConversionValueRule(marshal=marshal))
8 changes: 7 additions & 1 deletion google/cloud/aiplatform/initializer.py
Expand Up @@ -38,6 +38,7 @@
encryption_spec_v1beta1 as gca_encryption_spec_v1beta1,
)


class _Config:
"""Stores common parameters and options for API calls."""

Expand Down Expand Up @@ -100,7 +101,12 @@ def get_encryption_spec(
self,
encryption_spec_key_name: Optional[str],
select_version: Optional[str] = compat.DEFAULT_VERSION,
) -> Optional[Union[gca_encryption_spec_v1.EncryptionSpec, gca_encryption_spec_v1beta1.EncryptionSpec]]:
) -> Optional[
Union[
gca_encryption_spec_v1.EncryptionSpec,
gca_encryption_spec_v1beta1.EncryptionSpec,
]
]:
"""Creates a gca_encryption_spec.EncryptionSpec instance from the given key name.
If the provided key name is None, it uses the default key name if provided.
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/aiplatform/jobs.py
Expand Up @@ -487,7 +487,7 @@ def create(
# Optional Fields
gapic_batch_prediction_job.encryption_spec = initializer.global_config.get_encryption_spec(
encryption_spec_key_name=encryption_spec_key_name,
select_version=select_version
select_version=select_version,
)

if model_parameters:
Expand Down
9 changes: 0 additions & 9 deletions tests/unit/aiplatform/test_end_to_end.py
Expand Up @@ -25,15 +25,6 @@
from google.cloud.aiplatform import schema
from google.cloud.aiplatform import training_jobs

from google.cloud.aiplatform_v1beta1.types import (
dataset as gca_dataset_v1beta1,
encryption_spec as gca_encryption_spec_v1beta1,
io as gca_io_v1beta1,
model as gca_model_v1beta1,
pipeline_state as gca_pipeline_state_v1beta1,
training_pipeline as gca_training_pipeline_v1beta1,
)

from google.cloud.aiplatform_v1.types import (
dataset as gca_dataset,
encryption_spec as gca_encryption_spec,
Expand Down
9 changes: 3 additions & 6 deletions tests/unit/aiplatform/test_endpoints.py
Expand Up @@ -29,9 +29,6 @@
from google.cloud.aiplatform import models
from google.cloud.aiplatform import utils

from google.cloud.aiplatform_v1beta1.services.model_service import (
client as model_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.endpoint_service import (
client as endpoint_service_client_v1beta1,
)
Expand All @@ -40,11 +37,9 @@
)
from google.cloud.aiplatform_v1beta1.types import (
endpoint as gca_endpoint_v1beta1,
model as gca_model_v1beta1,
machine_resources as gca_machine_resources_v1beta1,
prediction_service as gca_prediction_service_v1beta1,
endpoint_service as gca_endpoint_service_v1beta1,
encryption_spec as gca_encryption_spec_v1beta1,
)

from google.cloud.aiplatform_v1.services.model_service import (
Expand Down Expand Up @@ -99,7 +94,9 @@
_TEST_ACCELERATOR_TYPE = "NVIDIA_TESLA_P100"
_TEST_ACCELERATOR_COUNT = 2

_TEST_EXPLANATIONS = [gca_prediction_service_v1beta1.explanation.Explanation(attributions=[])]
_TEST_EXPLANATIONS = [
gca_prediction_service_v1beta1.explanation.Explanation(attributions=[])
]

_TEST_ATTRIBUTIONS = [
gca_prediction_service_v1beta1.explanation.Attribution(
Expand Down
14 changes: 7 additions & 7 deletions tests/unit/aiplatform/test_jobs.py
Expand Up @@ -39,19 +39,15 @@
batch_prediction_job as gca_batch_prediction_job_v1beta1,
explanation as gca_explanation_v1beta1,
io as gca_io_v1beta1,
job_state as gca_job_state_v1beta1,
machine_resources as gca_machine_resources_v1beta1,
)

from google.cloud.aiplatform_v1.services.job_service import (
client as job_service_client,
)
from google.cloud.aiplatform_v1.services.job_service import client as job_service_client

from google.cloud.aiplatform_v1.types import (
batch_prediction_job as gca_batch_prediction_job,
io as gca_io,
job_state as gca_job_state,
machine_resources as gca_machine_resources,
)

_TEST_PROJECT = "test-project"
Expand Down Expand Up @@ -485,7 +481,9 @@ def test_batch_predict_gcs_source_bq_dest(

@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.usefixtures("get_batch_prediction_job_mock")
def test_batch_predict_with_all_args(self, create_batch_prediction_job_with_explanations_mock, sync):
def test_batch_predict_with_all_args(
self, create_batch_prediction_job_with_explanations_mock, sync
):
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
creds = auth_credentials.AnonymousCredentials()

Expand Down Expand Up @@ -518,7 +516,9 @@ def test_batch_predict_with_all_args(self, create_batch_prediction_job_with_expl
model=_TEST_MODEL_NAME,
input_config=gca_batch_prediction_job_v1beta1.BatchPredictionJob.InputConfig(
instances_format="jsonl",
gcs_source=gca_io_v1beta1.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]),
gcs_source=gca_io_v1beta1.GcsSource(
uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]
),
),
output_config=gca_batch_prediction_job_v1beta1.BatchPredictionJob.OutputConfig(
gcs_destination=gca_io_v1beta1.GcsDestination(
Expand Down
28 changes: 15 additions & 13 deletions tests/unit/aiplatform/test_models.py
Expand Up @@ -43,7 +43,6 @@
env_var as gca_env_var_v1beta1,
explanation as gca_explanation_v1beta1,
io as gca_io_v1beta1,
job_state as gca_job_state_v1beta1,
model as gca_model_v1beta1,
endpoint as gca_endpoint_v1beta1,
machine_resources as gca_machine_resources_v1beta1,
Expand All @@ -55,15 +54,12 @@
from google.cloud.aiplatform_v1.services.endpoint_service import (
client as endpoint_service_client,
)
from google.cloud.aiplatform_v1.services.job_service import (
client as job_service_client,
)
from google.cloud.aiplatform_v1.services.job_service import client as job_service_client
from google.cloud.aiplatform_v1.services.model_service import (
client as model_service_client,
)
from google.cloud.aiplatform_v1.types import (
batch_prediction_job as gca_batch_prediction_job,
env_var as gca_env_var,
io as gca_io,
job_state as gca_job_state,
model as gca_model,
Expand Down Expand Up @@ -184,6 +180,7 @@ def get_model_mock():
)
yield get_model_mock


@pytest.fixture
def get_model_with_explanations_mock():
with mock.patch.object(
Expand All @@ -194,6 +191,7 @@ def get_model_with_explanations_mock():
)
yield get_model_mock


@pytest.fixture
def get_model_with_custom_location_mock():
with mock.patch.object(
Expand Down Expand Up @@ -244,7 +242,6 @@ def upload_model_with_explanations_mock():
yield upload_model_mock



@pytest.fixture
def upload_model_with_custom_project_mock():
with mock.patch.object(
Expand Down Expand Up @@ -300,7 +297,6 @@ def deploy_model_mock():
yield deploy_model_mock



@pytest.fixture
def deploy_model_with_explanations_mock():
with mock.patch.object(
Expand Down Expand Up @@ -343,6 +339,7 @@ def create_batch_prediction_job_mock():
create_batch_prediction_job_mock.return_value = batch_prediction_job_mock
yield create_batch_prediction_job_mock


@pytest.fixture
def create_batch_prediction_job_with_explanations_mock():
with mock.patch.object(
Expand All @@ -355,6 +352,7 @@ def create_batch_prediction_job_with_explanations_mock():
create_batch_prediction_job_mock.return_value = batch_prediction_job_mock
yield create_batch_prediction_job_mock


@pytest.fixture
def create_client_mock():
with mock.patch.object(
Expand Down Expand Up @@ -746,7 +744,9 @@ def test_deploy_no_endpoint_dedicated_resources(self, deploy_model_mock, sync):
"get_endpoint_mock", "get_model_mock", "create_endpoint_mock"
)
@pytest.mark.parametrize("sync", [True, False])
def test_deploy_no_endpoint_with_explanations(self, deploy_model_with_explanations_mock, sync):
def test_deploy_no_endpoint_with_explanations(
self, deploy_model_with_explanations_mock, sync
):
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
test_model = models.Model(_TEST_ID)
test_endpoint = test_model.deploy(
Expand Down Expand Up @@ -834,9 +834,7 @@ def test_init_aiplatform_with_encryption_key_name_and_batch_predict_gcs_source_a
),
input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig(
instances_format="jsonl",
gcs_source=gca_io.GcsSource(
uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]
),
gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]),
),
output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig(
gcs_destination=gca_io.GcsDestination(
Expand Down Expand Up @@ -940,7 +938,9 @@ def test_batch_predict_gcs_source_bq_dest(

@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock")
def test_batch_predict_with_all_args(self, create_batch_prediction_job_with_explanations_mock, sync):
def test_batch_predict_with_all_args(
self, create_batch_prediction_job_with_explanations_mock, sync
):
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
test_model = models.Model(_TEST_ID)
creds = auth_credentials.AnonymousCredentials()
Expand Down Expand Up @@ -977,7 +977,9 @@ def test_batch_predict_with_all_args(self, create_batch_prediction_job_with_expl
),
input_config=gca_batch_prediction_job_v1beta1.BatchPredictionJob.InputConfig(
instances_format="jsonl",
gcs_source=gca_io_v1beta1.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]),
gcs_source=gca_io_v1beta1.GcsSource(
uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]
),
),
output_config=gca_batch_prediction_job_v1beta1.BatchPredictionJob.OutputConfig(
gcs_destination=gca_io_v1beta1.GcsDestination(
Expand Down

0 comments on commit 6356e96

Please sign in to comment.