Skip to content

Commit

Permalink
feat: add create_batch_prediction_job samples (#67)
Browse files Browse the repository at this point in the history
* chore: sample tests lint

* lint

* lnt

* lint

* feat: add create_batch_prediction_job samples

* lint
  • Loading branch information
dizcology committed Dec 1, 2020
1 parent 77956b2 commit 96a850f
Show file tree
Hide file tree
Showing 4 changed files with 296 additions and 0 deletions.
62 changes: 62 additions & 0 deletions samples/snippets/create_batch_prediction_job_bigquery_sample.py
@@ -0,0 +1,62 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# [START aiplatform_create_batch_prediction_job_bigquery_sample]
from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value


def create_batch_prediction_job_bigquery_sample(
project: str,
display_name: str,
model_name: str,
instances_format: str,
bigquery_source_input_uri: str,
predictions_format: str,
bigquery_destination_output_uri: str,
location: str = "us-central1",
api_endpoint: str = "us-central1-aiplatform.googleapis.com",
):
client_options = {"api_endpoint": api_endpoint}
# Initialize client that will be used to create and send requests.
# This client only needs to be created once, and can be reused for multiple requests.
client = aiplatform.gapic.JobServiceClient(client_options=client_options)
model_parameters_dict = {}
model_parameters = json_format.ParseDict(model_parameters_dict, Value())

batch_prediction_job = {
"display_name": display_name,
# Format: 'projects/{project}/locations/{location}/models/{model_id}'
"model": model_name,
"model_parameters": model_parameters,
"input_config": {
"instances_format": instances_format,
"bigquery_source": {"input_uri": bigquery_source_input_uri},
},
"output_config": {
"predictions_format": predictions_format,
"bigquery_destination": {"output_uri": bigquery_destination_output_uri},
},
# optional
"generate_explanation": True,
}
parent = f"projects/{project}/locations/{location}"
response = client.create_batch_prediction_job(
parent=parent, batch_prediction_job=batch_prediction_job
)
print("response:", response)


# [END aiplatform_create_batch_prediction_job_bigquery_sample]
@@ -0,0 +1,82 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from uuid import uuid4

from google.cloud import aiplatform
import pytest

import create_batch_prediction_job_bigquery_sample
import helpers

PROJECT_ID = os.getenv("BUILD_SPECIFIC_GCLOUD_PROJECT")
LOCATION = "us-central1"
MODEL_ID = "3125638878883479552" # bq all
DISPLAY_NAME = f"temp_create_batch_prediction_job_test_{uuid4()}"
BIGQUERY_SOURCE_INPUT_URI = "bq://ucaip-sample-tests.table_test.all_bq_types"
BIGQUERY_DESTINATION_OUTPUT_URI = "bq://ucaip-sample-tests"
INSTANCES_FORMAT = "bigquery"
PREDICTIONS_FORMAT = "bigquery"


@pytest.fixture
def shared_state():
state = {}
yield state


@pytest.fixture
def job_client():
job_client = aiplatform.gapic.JobServiceClient(
client_options={"api_endpoint": "us-central1-aiplatform.googleapis.com"}
)
return job_client


@pytest.fixture(scope="function", autouse=True)
def teardown(shared_state, job_client):
yield

job_client.cancel_batch_prediction_job(name=shared_state["batch_prediction_job_name"])

# Waiting until the job is in CANCELLED state.
helpers.wait_for_job_state(
get_job_method=job_client.get_batch_prediction_job,
name=shared_state["batch_prediction_job_name"],
)

job_client.delete_batch_prediction_job(name=shared_state["batch_prediction_job_name"])


def test_ucaip_generated_create_batch_prediction_job_bigquery_sample(
capsys, shared_state
):

model_name = f"projects/{PROJECT_ID}/locations/{LOCATION}/models/{MODEL_ID}"

create_batch_prediction_job_bigquery_sample.create_batch_prediction_job_bigquery_sample(
project=PROJECT_ID,
display_name=DISPLAY_NAME,
model_name=model_name,
bigquery_source_input_uri=BIGQUERY_SOURCE_INPUT_URI,
bigquery_destination_output_uri=BIGQUERY_DESTINATION_OUTPUT_URI,
instances_format=INSTANCES_FORMAT,
predictions_format=PREDICTIONS_FORMAT,
)

out, _ = capsys.readouterr()

# Save resource name of the newly created batch prediction job
shared_state["batch_prediction_job_name"] = helpers.get_name(out)
69 changes: 69 additions & 0 deletions samples/snippets/create_batch_prediction_job_sample.py
@@ -0,0 +1,69 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# [START aiplatform_create_batch_prediction_job_sample]
from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value


def create_batch_prediction_job_sample(
project: str,
display_name: str,
model_name: str,
instances_format: str,
gcs_source_uri: str,
predictions_format: str,
gcs_destination_output_uri_prefix: str,
location: str = "us-central1",
api_endpoint: str = "us-central1-aiplatform.googleapis.com",
):
client_options = {"api_endpoint": api_endpoint}
# Initialize client that will be used to create and send requests.
# This client only needs to be created once, and can be reused for multiple requests.
client = aiplatform.gapic.JobServiceClient(client_options=client_options)
model_parameters_dict = {}
model_parameters = json_format.ParseDict(model_parameters_dict, Value())

batch_prediction_job = {
"display_name": display_name,
# Format: 'projects/{project}/locations/{location}/models/{model_id}'
"model": model_name,
"model_parameters": model_parameters,
"input_config": {
"instances_format": instances_format,
"gcs_source": {"uris": [gcs_source_uri]},
},
"output_config": {
"predictions_format": predictions_format,
"gcs_destination": {"output_uri_prefix": gcs_destination_output_uri_prefix},
},
"dedicated_resources": {
"machine_spec": {
"machine_type": "n1-standard-2",
"accelerator_type": aiplatform.gapic.AcceleratorType.NVIDIA_TESLA_K80,
"accelerator_count": 1,
},
"starting_replica_count": 1,
"max_replica_count": 1,
},
}
parent = f"projects/{project}/locations/{location}"
response = client.create_batch_prediction_job(
parent=parent, batch_prediction_job=batch_prediction_job
)
print("response:", response)


# [END aiplatform_create_batch_prediction_job_sample]
83 changes: 83 additions & 0 deletions samples/snippets/create_batch_prediction_job_sample_test.py
@@ -0,0 +1,83 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from uuid import uuid4

from google.cloud import aiplatform
import pytest

import create_batch_prediction_job_sample
import helpers

PROJECT_ID = os.getenv("BUILD_SPECIFIC_GCLOUD_PROJECT")
LOCATION = "us-central1"
MODEL_ID = "1478306577684365312" # Permanent 50 flowers model
DISPLAY_NAME = f"temp_create_batch_prediction_job_test_{uuid4()}"
GCS_SOURCE_URI = (
"gs://ucaip-samples-test-output/inputs/icn_batch_prediction_input.jsonl"
)
GCS_OUTPUT_URI = "gs://ucaip-samples-test-output/"
INSTANCES_FORMAT = "jsonl"
PREDICTIONS_FORMAT = "jsonl"


@pytest.fixture
def shared_state():
state = {}
yield state


@pytest.fixture
def job_client():
job_client = aiplatform.gapic.JobServiceClient(
client_options={"api_endpoint": "us-central1-aiplatform.googleapis.com"}
)
return job_client


@pytest.fixture(scope="function", autouse=True)
def teardown(shared_state, job_client):
yield

job_client.cancel_batch_prediction_job(name=shared_state["batch_prediction_job_name"])

# Waiting until the job is in CANCELLED state.
helpers.wait_for_job_state(
get_job_method=job_client.get_batch_prediction_job,
name=shared_state["batch_prediction_job_name"],
)

job_client.delete_batch_prediction_job(name=shared_state["batch_prediction_job_name"])


# Creating AutoML Vision Classification batch prediction job
def test_ucaip_generated_create_batch_prediction_sample(capsys, shared_state):

model_name = f"projects/{PROJECT_ID}/locations/{LOCATION}/models/{MODEL_ID}"

create_batch_prediction_job_sample.create_batch_prediction_job_sample(
project=PROJECT_ID,
display_name=DISPLAY_NAME,
model_name=model_name,
gcs_source_uri=GCS_SOURCE_URI,
gcs_destination_output_uri_prefix=GCS_OUTPUT_URI,
instances_format=INSTANCES_FORMAT,
predictions_format=PREDICTIONS_FORMAT,
)

out, _ = capsys.readouterr()

# Save resource name of the newly created batch prediction job
shared_state["batch_prediction_job_name"] = helpers.get_name(out)

0 comments on commit 96a850f

Please sign in to comment.